"""Push distilled JSONL files to the cuilabs/bee-interactions HF dataset.
Why this exists
---------------
The distillation pipeline (bee/distillation.py) writes JSONL to
`data/datasets/distilled/<domain>.jsonl` with a teacher-student schema:
{ "instruction": ..., "input": ..., "output": ...,
"domain": ..., "teacher_model": ..., "sample_id": ... }
The Kaggle workers (workers/kaggle-{online,tpu}-train/train.py) load the
HF dataset `cuilabs/bee-interactions` and apply `rejection_reason()` over
each row. That function expects fields `prompt` / `content` / `domain` /
`task_type` / `target_tiers` / `quality_score` / `feedback` / `role`,
and rejects anything missing them. So distilled JSONL files would fail
the filter even if uploaded as-is.
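As a concrete example (all field values below are illustrative, not real data),
a distilled row such as
    {"instruction": "Explain DNS caching", "input": "", "output": "DNS caching ...",
     "domain": "networking", "teacher_model": "teacher-v1", "sample_id": "0001"}
is rewritten by transform_row() into
    {"role": "assistant", "prompt": "Explain DNS caching", "content": "DNS caching ...",
     "domain": "networking", "task_type": "general",
     "target_tiers": ["cell", "cell-plus", "comb"], "quality_score": 0.85, ...}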
This script bridges the two schemas and uploads the result as a single
timestamped JSONL under `distilled/` in the dataset, which the trainer's
`load_dataset(DATASET_ID, split="train")` call picks up automatically.
Usage
-----
HF_TOKEN=... python scripts/push_distilled_to_hf.py
Optional flags:
--input-dir <path> default: data/datasets/distilled
--dataset-id <repo> default: cuilabs/bee-interactions
--quality-score <0..1> default: 0.85 (teacher-distilled output is assumed high quality)
--target-tiers cell,cell-plus,comb
--dry-run show what would be uploaded, don't push
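Example dry run with non-default flags (values illustrative):
    HF_TOKEN=... python scripts/push_distilled_to_hf.py \
        --quality-score 0.9 --target-tiers cell,comb --dry-run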
Idempotency
-----------
The filename is timestamp-keyed and dropped under a `distilled/` prefix,
so re-running this script appends new files rather than overwriting.
The trainer dedupes via `dedupe_dataset()` (sha256 of prompt+content
prefix), so identical rows from earlier uploads are skipped at training
time.
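Two successive runs therefore leave two files in the dataset repo, e.g.
(timestamps illustrative):
    distilled/distilled-20250101-120000.jsonl
    distilled/distilled-20250315-093000.jsonl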
"""
from __future__ import annotations
import argparse
import datetime
import json
import os
import sys
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_INPUT_DIR = REPO_ROOT / "data/datasets/distilled"
DEFAULT_DATASET_ID = "cuilabs/bee-interactions"
DEFAULT_QUALITY_SCORE = 0.85
DEFAULT_TARGET_TIERS = ["cell", "cell-plus", "comb"]
def transform_row(distilled: dict, quality_score: float, target_tiers: list[str]) -> dict | None:
"""Map a distillation JSONL row to the trainer-expected schema.
Returns None if the row is missing required fields. The trainer's
`rejection_reason()` will reject any returned row that is too short
    or contains bad fragments; that's the point of having one filter
on the training side rather than two.
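    Illustrative example: instruction "Summarize this" plus input "<article>"
    become one prompt with a blank line between them, and `output` maps to `content`.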
"""
instruction = (distilled.get("instruction") or "").strip()
extra_input = (distilled.get("input") or "").strip()
output = (distilled.get("output") or "").strip()
if not instruction or not output:
return None
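    # Fold the optional `input` into the prompt as a second paragraph.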
prompt = instruction if not extra_input else f"{instruction}\n\n{extra_input}"
domain = (distilled.get("domain") or "general").strip() or "general"
return {
# Trainer-required fields
"role": "assistant",
"prompt": prompt,
"content": output,
"domain": domain,
"task_type": "general",
"target_tiers": target_tiers,
"quality_score": quality_score,
"feedback": None,
# Provenance
"model_id": distilled.get("teacher_model"),
"sample_id": distilled.get("sample_id"),
"source": "distilled",
"created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
}
def collect_rows(input_dir: Path, quality_score: float, target_tiers: list[str]) -> list[dict]:
rows: list[dict] = []
skipped = 0
    # Skip non-source files up front: corrections.jsonl (different schema,
    # optional) and prior upload-staging files (already in trainer schema).
    files = [
        p
        for p in sorted(input_dir.glob("*.jsonl"))
        if p.name != "corrections.jsonl" and not p.name.startswith("_upload-")
    ]
    if not files:
        print(f"[refuse] no distilled .jsonl files under {input_dir}", file=sys.stderr)
        sys.exit(2)
    for path in files:
        file_start = len(rows)
        with path.open(encoding="utf-8") as f:
for raw in f:
raw = raw.strip()
if not raw:
continue
try:
distilled = json.loads(raw)
except json.JSONDecodeError:
skipped += 1
continue
trainer_row = transform_row(distilled, quality_score, target_tiers)
if trainer_row is None:
skipped += 1
continue
rows.append(trainer_row)
print(f"[ok] read {path.name}: {sum(1 for r in rows if r['domain'] == path.stem)} usable rows so far")
if skipped:
print(f"[note] skipped {skipped} malformed/empty rows")
return rows
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--input-dir", default=str(DEFAULT_INPUT_DIR))
parser.add_argument("--dataset-id", default=DEFAULT_DATASET_ID)
parser.add_argument("--quality-score", type=float, default=DEFAULT_QUALITY_SCORE)
parser.add_argument(
"--target-tiers",
default=",".join(DEFAULT_TARGET_TIERS),
help="comma-separated tier names (default: cell,cell-plus,comb)",
)
parser.add_argument("--dry-run", action="store_true")
args = parser.parse_args()
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
if not hf_token and not args.dry_run:
sys.exit("HF_TOKEN missing β€” export it before running (or use --dry-run).")
input_dir = Path(args.input_dir)
if not input_dir.exists():
sys.exit(f"input dir does not exist: {input_dir}")
target_tiers = [t.strip() for t in args.target_tiers.split(",") if t.strip()]
rows = collect_rows(input_dir, args.quality_score, target_tiers)
if not rows:
sys.exit("no usable rows after transform β€” check input JSONL files.")
stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d-%H%M%S")
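    # Stage locally with an `_upload-` prefix so collect_rows() skips this file
    # if the script is re-run against the default input dir.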
out_path = REPO_ROOT / f"data/datasets/distilled/_upload-{stamp}.jsonl"
out_path.parent.mkdir(parents=True, exist_ok=True)
with out_path.open("w", encoding="utf-8") as f:
for row in rows:
f.write(json.dumps(row) + "\n")
print(f"[ok] staged {len(rows)} rows at {out_path}")
if args.dry_run:
print("[dry-run] not uploading. First 3 transformed rows:")
for row in rows[:3]:
print(json.dumps(row, indent=2)[:500])
return
# Lazy-import HF SDK so --dry-run works without it installed.
try:
from huggingface_hub import HfApi
except ImportError:
sys.exit("huggingface_hub not installed β€” `pip install huggingface_hub`")
api = HfApi(token=hf_token)
upload_path = f"distilled/distilled-{stamp}.jsonl"
print(f"[push] uploading {len(rows)} rows β†’ {args.dataset_id}:{upload_path}")
api.upload_file(
path_or_fileobj=str(out_path),
path_in_repo=upload_path,
repo_id=args.dataset_id,
repo_type="dataset",
commit_message=(
f"distilled: {len(rows)} examples from "
f"{sorted({r['domain'] for r in rows})} "
f"(quality={args.quality_score})"
),
)
print(f"[ok] uploaded https://huggingface.co/datasets/{args.dataset_id}/blob/main/{upload_path}")
if __name__ == "__main__":
main()