File size: 14,672 Bytes
080472b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390cebe
 
 
 
 
 
080472b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
"""Dataset Creator - Gradio app.

Lets you list any number of Hugging Face datasets, auto-detects (or lets
you manually map) how each one's rows turn into a chat-format triplet,
then combines, shuffles, and either pushes the result to your own HF
account or hands you a JSONL download.

UI wiring conventions used throughout this file, so the OAuth injection
keeps working:
  - Every per-entry callback takes `entries` (the gr.State list) and
    `uid` (a `gr.State(uid)` constant created at render time) as its
    first two parameters, matching the first two items in `inputs=`.
  - `oauth_token` / `oauth_profile` are declared as trailing, annotated
    parameters and are *never* included in `inputs=` - Gradio resolves
    them automatically from the signed-in session. See the Gradio OAuth
    guide; this is the one part of the app that needs a live Spaces
    OAuth session to fully smoke-test.
"""
from __future__ import annotations

import os
import random
import tempfile

import gradio as gr

import field_mapper
import schema_detect
import hf_dataset_loader
import hf_inspect
import hf_publish
from models import DatasetEntry, FieldMapping

# Human-readable text for each DatasetEntry.status value, rendered by
# _status_text(); a status missing from this table is shown verbatim.
_STATUS_LABELS = dict(
    empty="Not detected yet.",
    detecting="Checking schema...",
    needs_mapping="Couldn't auto-detect the field layout - map it manually below.",
    ready="Ready.",
    error="Error.",
)


def _find(entries: list, uid: str) -> DatasetEntry:
    """Return the first entry in *entries* whose ``uid`` equals *uid*.

    Raises:
        ValueError: if no entry carries that uid (e.g. it may have been
            removed while an event referencing it was still pending).
    """
    match = next((candidate for candidate in entries if candidate.uid == uid), None)
    if match is None:
        raise ValueError(f"No dataset entry with uid {uid!r} - it may have been removed.")
    return match


def _status_text(entry: DatasetEntry) -> str:
    """Build the Markdown status line shown under one dataset entry.

    The label comes from ``_STATUS_LABELS`` (falling back to the raw status).
    An ``error`` entry with a message gets that message appended; a ``ready``
    entry with a mapping gets the mapping kind appended.
    """
    prefix = f"**Status:** {_STATUS_LABELS.get(entry.status, entry.status)}"
    status = entry.status
    if status == "error":
        if entry.error_message:
            return f"{prefix} {entry.error_message}"
    elif status == "ready":
        if entry.mapping:
            return f"{prefix} (mapping type: `{entry.mapping.kind}`)"
    return prefix


# --- entry list mutation ----------------------------------------------------

def add_entry(entries: list) -> list:
    """Return a new list holding every existing entry plus one blank DatasetEntry.

    The input list itself is left unmodified.
    """
    return [*entries, DatasetEntry()]


def remove_entry(entries: list, uid: str) -> list:
    """Return a new list without any entry whose ``uid`` matches *uid*.

    Order of the remaining entries is preserved; the input list is untouched.
    """
    return list(filter(lambda candidate: candidate.uid != uid, entries))


def update_field(entries: list, uid: str, field_name: str, value) -> list:
    """Set attribute *field_name* to *value* on the entry identified by *uid*.

    The entry is mutated in place and the same list object is returned, so
    it can be written straight back to the ``entries`` gr.State.

    Raises:
        ValueError: (from ``_find``) if no entry has that uid.
    """
    setattr(_find(entries, uid), field_name, value)
    return entries


def update_limit(entries: list, uid: str, value) -> list:
    """Store the sample-limit input on the entry identified by *uid*.

    Falsy input (e.g. ``None`` or ``0``) is stored as ``0``; any other value
    is converted with ``int()``. The entry is mutated in place and the same
    list is returned.
    """
    target = _find(entries, uid)
    if value:
        target.limit = int(value)
    else:
        target.limit = 0
    return entries


# --- schema detection / mapping ---------------------------------------------

def detect_entry(entries: list, uid: str, oauth_token: gr.OAuthToken | None) -> list:
    """Sample a few rows of an entry's dataset and try to auto-detect its mapping.

    On success the entry's ``sample_rows``, ``detected_columns`` and
    ``detected_list_info`` are filled in, and its status becomes ``ready``
    (if ``schema_detect.auto_detect`` found a mapping) or ``needs_mapping``.
    Failures are recorded on the entry as ``status="error"`` with a message.

    Args:
        entries: The gr.State list of DatasetEntry objects (mutated in place).
        uid: Which entry to inspect.
        oauth_token: Injected by Gradio from the signed-in session; ``None``
            when nobody is signed in (anonymous access).

    Returns:
        The same ``entries`` list.
    """
    entry = _find(entries, uid)
    repo_id = entry.repo_id.strip()
    if not repo_id:
        entry.status = "error"
        entry.error_message = "Enter a dataset repo id first."
        return entries

    # NOTE(review): this function returns only once, so the UI presumably
    # never renders this intermediate status.
    entry.status = "detecting"
    token = oauth_token.token if oauth_token else None
    try:
        rows = hf_inspect.peek_rows(
            repo_id,
            entry.subset.strip(),
            entry.split.strip() or "train",
            sample_size=8,
            token=token,
        )
    except hf_inspect.DatasetInspectError as exc:
        entry.status = "error"
        entry.error_message = str(exc)
        return entries

    # An empty split (or an empty sample) has no first row to read columns
    # from; report it instead of crashing on rows[0].
    if not rows:
        entry.status = "error"
        entry.error_message = "The selected split returned no rows to inspect."
        return entries

    entry.sample_rows = rows
    entry.detected_columns = list(rows[0].keys())
    entry.detected_list_info = schema_detect.detect_list_column(rows)
    mapping = schema_detect.auto_detect(rows)

    entry.mapping = mapping if mapping else None
    entry.status = "ready" if mapping else "needs_mapping"
    entry.error_message = ""
    return entries


def apply_flat_mapping(entries: list, uid: str, user_field: str, assistant_field: str) -> list:
    """Attach a manual ``flat_pair`` mapping (one user column, one assistant column).

    Invalid selections are reported with ``gr.Error`` and leave the entry
    unchanged. Marking the entry ``status = "error"`` instead would hide the
    manual-mapping controls, because the render function only shows them
    while ``status == "needs_mapping"``. The user would then have to re-run
    schema detection just to retry.

    Returns:
        The same ``entries`` list, with the entry's mapping set and its status
        set to ``"ready"``.

    Raises:
        gr.Error: if either field is missing, or both name the same column.
        ValueError: (from ``_find``) if no entry has *uid*.
    """
    entry = _find(entries, uid)
    if not user_field or not assistant_field:
        raise gr.Error("Pick both a user field and an assistant field.")
    if user_field == assistant_field:
        # Mirrors the human/assistant tag check in apply_list_mapping: one
        # column on both sides would presumably yield identical turns.
        raise gr.Error("User and assistant fields must be different columns.")
    entry.mapping = FieldMapping(
        kind="flat_pair",
        config={"user_field": user_field, "assistant_field": assistant_field},
    )
    entry.status = "ready"
    entry.error_message = ""
    return entries


def apply_list_mapping(
    entries: list,
    uid: str,
    list_field: str,
    role_key: str,
    content_key: str,
    human_tag: str,
    assistant_tag: str,
) -> list:
    """Attach a manual ``conversation_list`` mapping to an entry.

    ``list_field``/``role_key``/``content_key`` come from the detected list
    info; the two tags are the values the user picked for the human and
    assistant roles. The assistant tag is stored under the ``gpt_tag`` config
    key.

    Invalid selections are reported with ``gr.Error`` and leave the entry
    unchanged. Setting ``status = "error"`` would hide the manual-mapping
    controls, which are rendered only while ``status == "needs_mapping"``.

    Returns:
        The same ``entries`` list, with the mapping set and status ``"ready"``.

    Raises:
        gr.Error: if any value is missing or the two tags are equal.
        ValueError: (from ``_find``) if no entry has *uid*.
    """
    entry = _find(entries, uid)
    if not all([list_field, role_key, content_key, human_tag, assistant_tag]):
        raise gr.Error("Fill in every field for the conversation-list mapping.")
    if human_tag == assistant_tag:
        raise gr.Error("Human and assistant tags can't be the same value.")
    entry.mapping = FieldMapping(
        kind="conversation_list",
        config={
            "list_field": list_field,
            "role_key": role_key,
            "content_key": content_key,
            "human_tag": human_tag,
            "gpt_tag": assistant_tag,
        },
    )
    entry.status = "ready"
    entry.error_message = ""
    return entries


# --- combine / export --------------------------------------------------------

def run_pipeline(
    entries: list,
    seed,
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(),
):
    """Load every ready entry, convert its rows to triplets, and shuffle them.

    Args:
        entries: The gr.State list of DatasetEntry objects.
        seed: Optional shuffle seed. ``None`` or ``""`` means an unseeded RNG.
        oauth_token: Injected by Gradio; ``None`` when nobody is signed in.
        progress: Gradio progress tracker (injected via the default value).

    Returns:
        ``(records, summary_markdown)``. ``records`` holds every truthy
        triplet from ``field_mapper.extract_triplet``, shuffled.

    Raises:
        gr.Error: if no entry is ready, or if loading any dataset fails.
    """
    ready_entries = [e for e in entries if e.status == "ready" and e.mapping is not None]
    if not ready_entries:
        raise gr.Error("No datasets are ready. Detect or manually map at least one dataset first.")

    token = oauth_token.token if oauth_token else None
    all_records = []
    breakdown_lines = []
    total = len(ready_entries)

    for idx, entry in enumerate(ready_entries):
        # Resolve the effective source once, so the summary reports the split
        # that was actually loaded (a blank split falls back to "train").
        repo_id = entry.repo_id.strip()
        split = entry.split.strip() or "train"
        progress((idx, total), desc=f"Loading {repo_id}")
        try:
            rows = hf_dataset_loader.load_limited(
                repo_id,
                entry.subset.strip(),
                split,
                int(entry.limit),
                token=token,
            )
        except Exception as exc:
            # UI boundary: surface any loader failure as a user-visible error
            # and keep the original cause chained for the server logs.
            raise gr.Error(f"Failed loading '{repo_id}': {exc}") from exc

        kept = 0
        for row in rows:
            triplet = field_mapper.extract_triplet(row, entry.mapping, entry.system_prompt)
            if triplet:
                all_records.append(triplet)
                kept += 1
        breakdown_lines.append(f"- `{repo_id}` ({split}): {kept}/{len(rows)} rows usable")

    progress((total, total), desc="Shuffling")
    rng = random.Random(int(seed)) if seed not in (None, "") else random.Random()
    rng.shuffle(all_records)

    summary = f"**Combined {len(all_records)} records from {total} dataset(s):**\n\n" + "\n".join(breakdown_lines)
    return all_records, summary


def do_push(
    records: list,
    repo_name: str,
    private: bool,
    oauth_token: gr.OAuthToken | None,
    oauth_profile: gr.OAuthProfile | None,
):
    """Push the combined records to ``<signed-in username>/<repo_name>``.

    Args:
        records: The combined records from ``run_pipeline``.
        repo_name: Dataset name; surrounding whitespace is ignored.
        private: Whether to create the dataset as private.
        oauth_token: Injected by Gradio from the signed-in session.
        oauth_profile: Injected by Gradio; supplies the target namespace.

    Returns:
        A Markdown message containing the URL returned by ``hf_publish``.

    Raises:
        gr.Error: if there is nothing to push, nobody is signed in, the name
            is blank, or the push itself fails.
    """
    if not records:
        raise gr.Error("Build the combined dataset first.")
    if not oauth_token or not oauth_profile:
        raise gr.Error("Sign in with your Hugging Face account first.")
    name = (repo_name or "").strip()
    if not name:
        raise gr.Error("Give the dataset a name.")

    repo_id = f"{oauth_profile.username}/{name}"
    try:
        url = hf_publish.push_dataset(records, repo_id, bool(private), oauth_token.token)
    except Exception as exc:
        # UI boundary: show the failure to the user, keep the cause chained.
        raise gr.Error(f"Push failed: {exc}") from exc
    return f"Pushed. View it at {url}"


def do_download(records: list):
    """Write the combined records to a fresh temp JSONL file and return its path.

    A new ``sage_dataset_*`` temp directory is created on every call; the file
    inside it is always named ``combined_dataset.jsonl``.

    Raises:
        gr.Error: if there are no records to write yet.
    """
    if not records:
        raise gr.Error("Build the combined dataset first.")
    target = os.path.join(
        tempfile.mkdtemp(prefix="sage_dataset_"),
        "combined_dataset.jsonl",
    )
    hf_publish.write_jsonl(records, target)
    return target


# --- UI -----------------------------------------------------------------------

# Top-level UI: a dynamic per-dataset list (rendered from entries_state),
# then the combine step and the export (push / download) step.
with gr.Blocks(title="Dataset Creator") as demo:
    gr.Markdown("# Dataset Creator")
    gr.Markdown(
        "Combine chat-format data from multiple Hugging Face datasets into one "
        "shuffled set. Sign in to push the result to your own account, or just "
        "download it as JSONL."
    )
    gr.LoginButton()

    # List of DatasetEntry objects (see add_entry / remove_entry).
    entries_state = gr.State([])
    # Shuffled records produced by run_pipeline, consumed by do_push / do_download.
    combined_state = gr.State([])

    add_btn = gr.Button("+ Add dataset")

    # Re-run by Gradio whenever entries_state changes; rebuilds one group of
    # controls per entry. Components get uid-based keys so each entry's
    # controls stay identifiable across re-renders.
    @gr.render(inputs=entries_state)
    def render_entries(entries):
        for entry in entries:
            uid = entry.uid
            # Constant per-entry state so callbacks know which entry they target.
            uid_state = gr.State(uid)

            with gr.Group():
                gr.Markdown(f"**{entry.repo_id or 'New dataset'}**", key=f"title-{uid}")

                with gr.Row():
                    repo_tb = gr.Textbox(
                        label="HF dataset repo",
                        value=entry.repo_id,
                        placeholder="e.g. NousResearch/hermes-function-calling-v1",
                        key=f"repo-{uid}",
                    )
                    subset_tb = gr.Textbox(
                        label="Config / subset (optional)", value=entry.subset, key=f"subset-{uid}"
                    )
                    # Displays "train" for an empty split, matching the
                    # `split.strip() or "train"` fallback in detect_entry and
                    # run_pipeline.
                    split_tb = gr.Textbox(label="Split", value=entry.split or "train", key=f"split-{uid}")
                    limit_num = gr.Number(
                        label="Sample limit", value=entry.limit, precision=0, key=f"limit-{uid}"
                    )

                system_tb = gr.Textbox(
                    label="System prompt override",
                    value=entry.system_prompt,
                    lines=2,
                    key=f"sys-{uid}",
                )

                # Constant states naming the DatasetEntry attribute that
                # update_field should set for each textbox.
                repo_field = gr.State("repo_id")
                subset_field = gr.State("subset")
                split_field = gr.State("split")
                sys_field = gr.State("system_prompt")

                repo_tb.change(update_field, [entries_state, uid_state, repo_field, repo_tb], entries_state)
                subset_tb.change(update_field, [entries_state, uid_state, subset_field, subset_tb], entries_state)
                split_tb.change(update_field, [entries_state, uid_state, split_field, split_tb], entries_state)
                limit_num.change(update_limit, [entries_state, uid_state, limit_num], entries_state)
                system_tb.change(update_field, [entries_state, uid_state, sys_field, system_tb], entries_state)

                with gr.Row():
                    detect_btn = gr.Button("Detect schema", key=f"detect-{uid}")
                    remove_btn = gr.Button("Remove", variant="stop", key=f"remove-{uid}")

                gr.Markdown(_status_text(entry), key=f"status-{uid}")

                # oauth_token is not listed in inputs; Gradio injects it into
                # detect_entry from the signed-in session.
                detect_btn.click(detect_entry, [entries_state, uid_state], entries_state)
                remove_btn.click(remove_entry, [entries_state, uid_state], entries_state)

                # Manual-mapping controls are shown only while auto-detection
                # has failed for this entry.
                if entry.status == "needs_mapping":
                    gr.Markdown(
                        f"Columns found: `{', '.join(entry.detected_columns)}`", key=f"cols-{uid}"
                    )
                    with gr.Row():
                        user_field_dd = gr.Dropdown(
                            choices=entry.detected_columns, label="User field", key=f"userf-{uid}"
                        )
                        asst_field_dd = gr.Dropdown(
                            choices=entry.detected_columns, label="Assistant field", key=f"asstf-{uid}"
                        )
                    apply_flat_btn = gr.Button("Apply flat mapping", key=f"applyflat-{uid}")
                    apply_flat_btn.click(
                        apply_flat_mapping,
                        [entries_state, uid_state, user_field_dd, asst_field_dd],
                        entries_state,
                    )

                    # Offered only when schema_detect.detect_list_column found
                    # a conversation-style list column in the sampled rows.
                    if entry.detected_list_info:
                        info = entry.detected_list_info
                        gr.Markdown(
                            f"Also found a conversation-style column: `{info['list_field']}` "
                            f"(role key `{info['role_key']}`, content key `{info['content_key']}`, "
                            f"tags seen: {info['tag_values']})",
                            key=f"listinfo-{uid}",
                        )
                        with gr.Row():
                            human_tag_dd = gr.Dropdown(
                                choices=info["tag_values"], label="Human/user tag", key=f"humant-{uid}"
                            )
                            asst_tag_dd = gr.Dropdown(
                                choices=info["tag_values"], label="Assistant tag", key=f"asstt-{uid}"
                            )
                        apply_list_btn = gr.Button(
                            "Apply conversation-list mapping", key=f"applylist-{uid}"
                        )

                        # The detected field/key names are passed through as
                        # fixed states; only the two tags are user-chosen.
                        list_field_state = gr.State(info["list_field"])
                        role_key_state = gr.State(info["role_key"])
                        content_key_state = gr.State(info["content_key"])

                        apply_list_btn.click(
                            apply_list_mapping,
                            [
                                entries_state,
                                uid_state,
                                list_field_state,
                                role_key_state,
                                content_key_state,
                                human_tag_dd,
                                asst_tag_dd,
                            ],
                            entries_state,
                        )

    add_btn.click(add_entry, entries_state, entries_state)

    gr.Markdown("---")
    gr.Markdown("## Combine")
    # Empty box -> None -> unseeded shuffle in run_pipeline.
    seed_num = gr.Number(label="Shuffle seed (optional)", precision=0)
    build_btn = gr.Button("Build combined dataset", variant="primary")
    summary_md = gr.Markdown()
    build_btn.click(run_pipeline, [entries_state, seed_num], [combined_state, summary_md])

    gr.Markdown("## Export")
    with gr.Row():
        repo_name_tb = gr.Textbox(label="Dataset name (goes under your HF username)")
        private_cb = gr.Checkbox(label="Private", value=True)
    push_btn = gr.Button("Push to Hub")
    push_result_md = gr.Markdown()
    # oauth_token / oauth_profile are injected by Gradio, not listed here.
    push_btn.click(do_push, [combined_state, repo_name_tb, private_cb], push_result_md)

    download_btn = gr.Button("Download as JSONL")
    download_file = gr.File(label="Download")
    download_btn.click(do_download, combined_state, download_file)

# Launch the app only when run as a script.
if __name__ == "__main__":
    demo.launch()