"""Stream+shuffle-sample the e621 SFW dataset and save a trimmed JSONL for prompt experiments."""

import argparse
import hashlib
import json
from pathlib import Path

DATASET_ID = "CaptionEmporium/furry-e621-sfw-7m-hq"
SPLIT = "train"

# Adjust these names if your actual columns differ.
CAPTION_FIELDS = ["caption_llm_6", "caption_llm_8", "caption_cogvlm"]
KEEP_FIELDS = ["tags_ground_truth_categorized"] + CAPTION_FIELDS


def pick_id(row: dict) -> str:
    """Return a stable string id for *row*.

    Tries a few common id keys first; otherwise derives a deterministic id
    from the first non-empty caption field.
    """
    for k in ("id", "post_id", "e621_id", "image_id"):
        if k in row and row[k] not in (None, ""):
            return str(row[k])
    # Fallback: a *deterministic* digest of the caption text. The builtin
    # hash() is salted per process (PYTHONHASHSEED), so it would yield a
    # different id on every run — a cryptographic digest is stable.
    base = (
        row.get("caption_llm_6")
        or row.get("caption_llm_8")
        or row.get("caption_cogvlm")
        or ""
    )
    digest = hashlib.sha256(base.encode("utf-8")).hexdigest()[:16]
    return f"no_id:{digest}"


def main() -> None:
    """Parse CLI args, stream+shuffle the dataset, and write trimmed rows as JSONL."""
    # Deferred import: keeps this module importable (e.g. for testing
    # pick_id) on machines without the heavy `datasets` dependency.
    from datasets import load_dataset

    ap = argparse.ArgumentParser(
        description="Stream+shuffle sample and save a trimmed JSONL for prompt experiments."
    )
    ap.add_argument("--n", type=int, default=1000)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--buffer-size", type=int, default=10_000)
    ap.add_argument(
        "--out",
        type=str,
        default="data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_trimmed.jsonl",
    )
    ap.add_argument(
        "--require-any-caption",
        action="store_true",
        help="If set, only keep rows where at least one of the caption fields is non-empty.",
    )
    args = ap.parse_args()

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Streaming avoids downloading the full 7M-row dataset; shuffle uses a
    # bounded reservoir buffer, so the sample is approximate, not uniform.
    ds = load_dataset(DATASET_ID, split=SPLIT, streaming=True)
    ds = ds.shuffle(seed=args.seed, buffer_size=args.buffer_size)

    wrote = 0
    with out_path.open("w", encoding="utf-8") as f:
        for row in ds:
            out = {"row_id": pick_id(row)}
            for k in KEEP_FIELDS:
                out[k] = row.get(k, "")
            if args.require_any_caption:
                # Skip rows where every caption field is missing/blank.
                if not any((out.get(c) or "").strip() for c in CAPTION_FIELDS):
                    continue
            f.write(json.dumps(out, ensure_ascii=False) + "\n")
            wrote += 1
            if wrote >= args.n:
                break

    print(f"Wrote {wrote} rows to: {out_path}")


if __name__ == "__main__":
    main()