File size: 7,377 Bytes
5e21013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""Push distilled JSONL files to the cuilabs/bee-interactions HF dataset.

Why this exists
---------------
The distillation pipeline (bee/distillation.py) writes JSONL to
`data/datasets/distilled/<domain>.jsonl` with a teacher-student schema:

    { "instruction": ..., "input": ..., "output": ...,
      "domain": ..., "teacher_model": ..., "sample_id": ... }

The Kaggle workers (workers/kaggle-{online,tpu}-train/train.py) load the
HF dataset `cuilabs/bee-interactions` and apply `rejection_reason()` over
each row. That function expects fields `prompt` / `content` / `domain` /
`task_type` / `target_tiers` / `quality_score` / `feedback` / `role` —
and rejects anything missing them. So distilled JSONL files would fail
the filter even if uploaded as-is.

This script bridges the two schemas and uploads the result as a single
timestamped JSONL under `distilled/` in the dataset, which the trainer's
`load_dataset(DATASET_ID, split="train")` call picks up automatically.

Usage
-----
    HF_TOKEN=... python scripts/push_distilled_to_hf.py

Optional flags:
    --input-dir <path>      default: data/datasets/distilled
    --dataset-id <repo>     default: cuilabs/bee-interactions
    --quality-score <0..1>  default: 0.85 (teacher-distilled is high)
    --target-tiers cell,cell-plus,comb
    --dry-run               show what would be uploaded, don't push

Idempotency
-----------
The filename is timestamp-keyed and dropped under a `distilled/` prefix,
so re-running this script appends new files rather than overwriting.
The trainer dedupes via `dedupe_dataset()` (sha256 of prompt+content
prefix), so identical rows from earlier uploads are skipped at training
time.
"""
from __future__ import annotations

import argparse
import datetime
import json
import os
import sys
from pathlib import Path

# Directory two levels above this file (the parent of the directory that
# contains this script); used as the base for the default data paths.
REPO_ROOT = Path(__file__).resolve().parent.parent
# Default location scanned for distilled `*.jsonl` files (overridable with --input-dir).
DEFAULT_INPUT_DIR = REPO_ROOT / "data/datasets/distilled"
# Default Hugging Face dataset repo that receives the upload (--dataset-id).
DEFAULT_DATASET_ID = "cuilabs/bee-interactions"
# Default `quality_score` stamped on every transformed row (--quality-score).
DEFAULT_QUALITY_SCORE = 0.85
# Default `target_tiers` stamped on every transformed row (--target-tiers).
DEFAULT_TARGET_TIERS = ["cell", "cell-plus", "comb"]


def transform_row(distilled: dict, quality_score: float, target_tiers: list[str]) -> dict | None:
    """Map a distillation JSONL row to the trainer-expected schema.

    Args:
        distilled: One parsed JSONL record using the distillation keys
            (`instruction`, `input`, `output`, `domain`, `teacher_model`,
            `sample_id`).
        quality_score: Value stored verbatim in the row's `quality_score`.
        target_tiers: Tier names stored in the row's `target_tiers`; each
            returned row gets its own copy of the list.

    Returns:
        The row in trainer schema, or None when the record is unusable:
        it is not a JSON object, `instruction` or `output` is missing,
        empty, or not a string, or `input` is present but not a string.

    The trainer's `rejection_reason()` will reject any returned row that
    is too short or contains bad fragments — that's the point of having
    one filter on the training side rather than two.
    """
    # A JSONL line can legally decode to a list/str/number; only objects
    # can carry the fields we need.
    if not isinstance(distilled, dict):
        return None

    def _text(key: str) -> str | None:
        # Missing / null / falsy values count as empty text; a truthy
        # non-string (number, list, object) is reported as None so the
        # caller can tell "empty" from "wrong type" instead of crashing
        # on `.strip()`.
        value = distilled.get(key) or ""
        return value.strip() if isinstance(value, str) else None

    instruction = _text("instruction")
    extra_input = _text("input")
    output = _text("output")
    if not instruction or not output or extra_input is None:
        return None

    prompt = instruction if not extra_input else f"{instruction}\n\n{extra_input}"
    # Blank, missing, or non-string domains all fall back to "general".
    domain = _text("domain") or "general"

    return {
        # Trainer-required fields
        "role": "assistant",
        "prompt": prompt,
        "content": output,
        "domain": domain,
        "task_type": "general",
        # Copy so rows never share (and cannot mutate) the caller's list.
        "target_tiers": list(target_tiers),
        "quality_score": quality_score,
        "feedback": None,
        # Provenance
        "model_id": distilled.get("teacher_model"),
        "sample_id": distilled.get("sample_id"),
        "source": "distilled",
        "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
    }


def collect_rows(input_dir: Path, quality_score: float, target_tiers: list[str]) -> list[dict]:
    """Read every source `*.jsonl` under `input_dir` and transform its rows.

    Files are processed in sorted order. `corrections.jsonl` and files
    whose name starts with `_upload-` are ignored. Blank lines are
    skipped silently; lines that are not valid JSON, are not JSON
    objects, or that `transform_row()` rejects are counted and reported
    in a single `[note]` line.

    Args:
        input_dir: Directory to scan (non-recursive) for `*.jsonl` files.
        quality_score: Passed through to `transform_row()`.
        target_tiers: Passed through to `transform_row()`.

    Returns:
        All transformed rows, in file order then line order.

    Raises:
        SystemExit: With code 2 when `input_dir` contains no `.jsonl`
            files at all.
    """
    files = sorted(input_dir.glob("*.jsonl"))
    if not files:
        print(f"[refuse] no .jsonl files under {input_dir}", file=sys.stderr)
        sys.exit(2)

    rows: list[dict] = []
    skipped = 0
    for path in files:
        # Skip non-source files: corrections (different schema, optional)
        # and prior upload-staging files (already in trainer schema).
        if path.name == "corrections.jsonl" or path.name.startswith("_upload-"):
            continue
        rows_before = len(rows)
        with path.open(encoding="utf-8") as f:
            for raw in f:
                raw = raw.strip()
                if not raw:
                    continue
                try:
                    distilled = json.loads(raw)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                # Valid JSON that is not an object (list, string, number)
                # cannot be a distillation record; treat it as malformed.
                if not isinstance(distilled, dict):
                    skipped += 1
                    continue
                trainer_row = transform_row(distilled, quality_score, target_tiers)
                if trainer_row is None:
                    skipped += 1
                    continue
                rows.append(trainer_row)
        # Report what this file contributed, independent of the rows'
        # `domain` values (which need not match the file name).
        added = len(rows) - rows_before
        print(f"[ok] read {path.name}: {added} usable rows ({len(rows)} total so far)")
    if skipped:
        print(f"[note] skipped {skipped} malformed/empty rows")
    return rows


def main() -> None:
    """CLI entry point: transform distilled JSONL, stage it, and upload it.

    Steps:
      1. Parse and validate the command-line flags.
      2. Require `HF_TOKEN` (or `HUGGINGFACE_HUB_TOKEN`) unless --dry-run.
      3. Collect and transform rows from the input directory.
      4. Write them to a timestamped `_upload-<stamp>.jsonl` staging file
         under `data/datasets/distilled/` in the repo (also on --dry-run).
      5. Unless --dry-run, upload the staging file to the dataset repo as
         `distilled/distilled-<stamp>.jsonl`.

    Raises:
        SystemExit: On invalid flags, a missing token, a missing input
            directory, no usable rows, or a missing `huggingface_hub`.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring is pre-formatted (headings, indented
        # examples); keep its line breaks in --help instead of reflowing.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--input-dir", default=str(DEFAULT_INPUT_DIR))
    parser.add_argument("--dataset-id", default=DEFAULT_DATASET_ID)
    parser.add_argument("--quality-score", type=float, default=DEFAULT_QUALITY_SCORE)
    parser.add_argument(
        "--target-tiers",
        default=",".join(DEFAULT_TARGET_TIERS),
        help="comma-separated tier names (default: cell,cell-plus,comb)",
    )
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    # The flag is documented as 0..1; the chained comparison also rejects NaN.
    if not 0.0 <= args.quality_score <= 1.0:
        parser.error(f"--quality-score must be within 0..1, got {args.quality_score}")

    target_tiers = [t.strip() for t in args.target_tiers.split(",") if t.strip()]
    if not target_tiers:
        parser.error("--target-tiers must name at least one tier")

    hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not hf_token and not args.dry_run:
        sys.exit("HF_TOKEN missing — export it before running (or use --dry-run).")

    input_dir = Path(args.input_dir)
    # is_dir() rather than exists(): a regular file is not a usable input dir.
    if not input_dir.is_dir():
        sys.exit(f"input dir does not exist or is not a directory: {input_dir}")

    rows = collect_rows(input_dir, args.quality_score, target_tiers)
    if not rows:
        sys.exit("no usable rows after transform — check input JSONL files.")

    # One UTC timestamp names both the local staging file and the remote file.
    stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d-%H%M%S")
    out_path = REPO_ROOT / f"data/datasets/distilled/_upload-{stamp}.jsonl"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in rows)
    print(f"[ok] staged {len(rows)} rows at {out_path}")

    if args.dry_run:
        print("[dry-run] not uploading. First 3 transformed rows:")
        for row in rows[:3]:
            # Truncate the preview so long outputs do not flood the terminal.
            print(json.dumps(row, indent=2)[:500])
        return

    # Lazy-import HF SDK so --dry-run works without it installed.
    try:
        from huggingface_hub import HfApi
    except ImportError:
        sys.exit("huggingface_hub not installed — `pip install huggingface_hub`")

    api = HfApi(token=hf_token)
    upload_path = f"distilled/distilled-{stamp}.jsonl"
    print(f"[push] uploading {len(rows)} rows → {args.dataset_id}:{upload_path}")
    api.upload_file(
        path_or_fileobj=str(out_path),
        path_in_repo=upload_path,
        repo_id=args.dataset_id,
        repo_type="dataset",
        commit_message=(
            f"distilled: {len(rows)} examples from "
            f"{sorted({r['domain'] for r in rows})} "
            f"(quality={args.quality_score})"
        ),
    )
    print(f"[ok] uploaded https://huggingface.co/datasets/{args.dataset_id}/blob/main/{upload_path}")


if __name__ == "__main__":
    main()