Spaces:
Running
Running
import argparse
import hashlib
import json
from pathlib import Path

from datasets import load_dataset
# Hugging Face dataset to stream from, and which split to read.
DATASET_ID = "CaptionEmporium/furry-e621-sfw-7m-hq"
SPLIT = "train"

# Caption columns copied into each output row.
# Adjust these names if your actual columns differ.
CAPTION_FIELDS = [
    "caption_llm_6",
    "caption_llm_8",
    "caption_cogvlm",
]

# All columns copied into each output row: the tag column first, then captions.
KEEP_FIELDS = ["tags_ground_truth_categorized", *CAPTION_FIELDS]
def pick_id(row: dict) -> str:
    """Return an identifier string for a dataset row.

    The value of the first key among ``id``, ``post_id``, ``e621_id`` and
    ``image_id`` that is present and not ``None``/``""`` is returned as a
    ``str``.  If none qualifies, a fallback ``"no_id:<sha1-hex>"`` is derived
    from the first truthy caption among ``caption_llm_6``, ``caption_llm_8``,
    ``caption_cogvlm`` (or ``""`` when all are empty/missing).

    Args:
        row: One dataset row as a mapping of column name to value.

    Returns:
        The row's id, or the caption-derived fallback id.
    """
    # Try a few common id keys first.
    for k in ("id", "post_id", "e621_id", "image_id"):
        value = row.get(k)
        if value is not None and value != "":
            return str(value)

    # Fallback: hash the caption text.  Built-in hash() is salted per process
    # for str (PYTHONHASHSEED), so it would yield a different id on every run;
    # a hashlib digest is reproducible across runs and machines.
    # Note: rows with no id and no caption all collapse onto the same id.
    base = row.get("caption_llm_6") or row.get("caption_llm_8") or row.get("caption_cogvlm") or ""
    digest = hashlib.sha1(str(base).encode("utf-8")).hexdigest()
    return f"no_id:{digest}"
def main() -> None:
    """Stream a shuffled sample of DATASET_ID and write trimmed rows as JSONL.

    Each output line is a JSON object with ``row_id`` (from ``pick_id``) plus
    every column in ``KEEP_FIELDS``; a column missing from the row is written
    as ``""``.  Writing stops after ``--n`` rows or when the stream ends.

    Command-line options:
        --n: number of rows to write (must be >= 1).
        --seed: seed passed to the streaming shuffle.
        --buffer-size: shuffle buffer size (must be >= 1).
        --out: output path.  By default the name embeds the actual --n,
            --seed and --buffer-size values, so it never mislabels a run.
        --require-any-caption: skip rows whose caption fields are all empty.
    """
    ap = argparse.ArgumentParser(description="Stream+shuffle sample and save a trimmed JSONL for prompt experiments.")
    ap.add_argument("--n", type=int, default=1000)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--buffer-size", type=int, default=10_000)
    ap.add_argument(
        "--out",
        type=str,
        default=None,
        help="Output JSONL path (default: derived from --n, --seed and --buffer-size).",
    )
    ap.add_argument(
        "--require-any-caption",
        action="store_true",
        help="If set, only keep rows where at least one of the caption fields is non-empty.",
    )
    args = ap.parse_args()

    # The limit is checked only after a row is written, so n <= 0 would still
    # emit one row; reject such values up front.
    if args.n < 1:
        ap.error("--n must be >= 1")
    if args.buffer_size < 1:
        ap.error("--buffer-size must be >= 1")

    # Derive the default name from the real settings; with default flags this
    # is exactly ".../e621_sfw_sample_1000_seed123_buffer10000_trimmed.jsonl".
    out_arg = args.out
    if out_arg is None:
        out_arg = (
            f"data/eval_samples/e621_sfw_sample_{args.n}_seed{args.seed}"
            f"_buffer{args.buffer_size}_trimmed.jsonl"
        )
    out_path = Path(out_arg)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    ds = load_dataset(DATASET_ID, split=SPLIT, streaming=True)
    ds = ds.shuffle(seed=args.seed, buffer_size=args.buffer_size)

    wrote = 0
    with out_path.open("w", encoding="utf-8") as f:
        for row in ds:
            out = {"row_id": pick_id(row)}
            for k in KEEP_FIELDS:
                out[k] = row.get(k, "")
            # str() so a non-string caption value cannot crash .strip().
            if args.require_any_caption and not any(
                str(out.get(c) or "").strip() for c in CAPTION_FIELDS
            ):
                continue
            f.write(json.dumps(out, ensure_ascii=False) + "\n")
            wrote += 1
            if wrote >= args.n:
                break

    if wrote < args.n:
        print(f"Warning: stream ended after {wrote} rows (requested {args.n}).")
    print(f"Wrote {wrote} rows to: {out_path}")
if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0.
    raise SystemExit(main())