"""Push distilled JSONL files to the cuilabs/bee-interactions HF dataset. Why this exists --------------- The distillation pipeline (bee/distillation.py) writes JSONL to `data/datasets/distilled/.jsonl` with a teacher-student schema: { "instruction": ..., "input": ..., "output": ..., "domain": ..., "teacher_model": ..., "sample_id": ... } The Kaggle workers (workers/kaggle-{online,tpu}-train/train.py) load the HF dataset `cuilabs/bee-interactions` and apply `rejection_reason()` over each row. That function expects fields `prompt` / `content` / `domain` / `task_type` / `target_tiers` / `quality_score` / `feedback` / `role` — and rejects anything missing them. So distilled JSONL files would fail the filter even if uploaded as-is. This script bridges the two schemas and uploads the result as a single timestamped JSONL under `distilled/` in the dataset, which the trainer's `load_dataset(DATASET_ID, split="train")` call picks up automatically. Usage ----- HF_TOKEN=... python scripts/push_distilled_to_hf.py Optional flags: --input-dir default: data/datasets/distilled --dataset-id default: cuilabs/bee-interactions --quality-score <0..1> default: 0.85 (teacher-distilled is high) --target-tiers cell,cell-plus,comb --dry-run show what would be uploaded, don't push Idempotency ----------- The filename is timestamp-keyed and dropped under a `distilled/` prefix, so re-running this script appends new files rather than overwriting. The trainer dedupes via `dedupe_dataset()` (sha256 of prompt+content prefix), so identical rows from earlier uploads are skipped at training time. """ from __future__ import annotations import argparse import datetime import json import os import sys from pathlib import Path REPO_ROOT = Path(__file__).resolve().parent.parent DEFAULT_INPUT_DIR = REPO_ROOT / "data/datasets/distilled" DEFAULT_DATASET_ID = "cuilabs/bee-interactions" DEFAULT_QUALITY_SCORE = 0.85 DEFAULT_TARGET_TIERS = ["cell", "cell-plus", "comb"] def transform_row(distilled: dict, quality_score: float, target_tiers: list[str]) -> dict | None: """Map a distillation JSONL row to the trainer-expected schema. Returns None if the row is missing required fields. The trainer's `rejection_reason()` will reject any returned row that is too short or contains bad fragments — that's the point of having one filter on the training side rather than two. """ instruction = (distilled.get("instruction") or "").strip() extra_input = (distilled.get("input") or "").strip() output = (distilled.get("output") or "").strip() if not instruction or not output: return None prompt = instruction if not extra_input else f"{instruction}\n\n{extra_input}" domain = (distilled.get("domain") or "general").strip() or "general" return { # Trainer-required fields "role": "assistant", "prompt": prompt, "content": output, "domain": domain, "task_type": "general", "target_tiers": target_tiers, "quality_score": quality_score, "feedback": None, # Provenance "model_id": distilled.get("teacher_model"), "sample_id": distilled.get("sample_id"), "source": "distilled", "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(), } def collect_rows(input_dir: Path, quality_score: float, target_tiers: list[str]) -> list[dict]: rows: list[dict] = [] skipped = 0 files = sorted(input_dir.glob("*.jsonl")) if not files: print(f"[refuse] no .jsonl files under {input_dir}", file=sys.stderr) sys.exit(2) for path in files: # Skip non-source files: corrections (different schema, optional) # and prior upload-staging files (already in trainer schema). if path.name == "corrections.jsonl" or path.name.startswith("_upload-"): continue with path.open(encoding="utf-8") as f: for raw in f: raw = raw.strip() if not raw: continue try: distilled = json.loads(raw) except json.JSONDecodeError: skipped += 1 continue trainer_row = transform_row(distilled, quality_score, target_tiers) if trainer_row is None: skipped += 1 continue rows.append(trainer_row) print(f"[ok] read {path.name}: {sum(1 for r in rows if r['domain'] == path.stem)} usable rows so far") if skipped: print(f"[note] skipped {skipped} malformed/empty rows") return rows def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--input-dir", default=str(DEFAULT_INPUT_DIR)) parser.add_argument("--dataset-id", default=DEFAULT_DATASET_ID) parser.add_argument("--quality-score", type=float, default=DEFAULT_QUALITY_SCORE) parser.add_argument( "--target-tiers", default=",".join(DEFAULT_TARGET_TIERS), help="comma-separated tier names (default: cell,cell-plus,comb)", ) parser.add_argument("--dry-run", action="store_true") args = parser.parse_args() hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") if not hf_token and not args.dry_run: sys.exit("HF_TOKEN missing — export it before running (or use --dry-run).") input_dir = Path(args.input_dir) if not input_dir.exists(): sys.exit(f"input dir does not exist: {input_dir}") target_tiers = [t.strip() for t in args.target_tiers.split(",") if t.strip()] rows = collect_rows(input_dir, args.quality_score, target_tiers) if not rows: sys.exit("no usable rows after transform — check input JSONL files.") stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d-%H%M%S") out_path = REPO_ROOT / f"data/datasets/distilled/_upload-{stamp}.jsonl" out_path.parent.mkdir(parents=True, exist_ok=True) with out_path.open("w", encoding="utf-8") as f: for row in rows: f.write(json.dumps(row) + "\n") print(f"[ok] staged {len(rows)} rows at {out_path}") if args.dry_run: print("[dry-run] not uploading. First 3 transformed rows:") for row in rows[:3]: print(json.dumps(row, indent=2)[:500]) return # Lazy-import HF SDK so --dry-run works without it installed. try: from huggingface_hub import HfApi except ImportError: sys.exit("huggingface_hub not installed — `pip install huggingface_hub`") api = HfApi(token=hf_token) upload_path = f"distilled/distilled-{stamp}.jsonl" print(f"[push] uploading {len(rows)} rows → {args.dataset_id}:{upload_path}") api.upload_file( path_or_fileobj=str(out_path), path_in_repo=upload_path, repo_id=args.dataset_id, repo_type="dataset", commit_message=( f"distilled: {len(rows)} examples from " f"{sorted({r['domain'] for r in rows})} " f"(quality={args.quality_score})" ), ) print(f"[ok] uploaded https://huggingface.co/datasets/{args.dataset_id}/blob/main/{upload_path}") if __name__ == "__main__": main()