Spaces:
Running
Running
File size: 2,260 Bytes
import argparse
import json
from pathlib import Path
from datasets import load_dataset
# Dataset identifier and split handed to `load_dataset`.
DATASET_ID = "CaptionEmporium/furry-e621-sfw-7m-hq"
SPLIT = "train"

# Caption columns copied into each output record; edit this list if the
# dataset's actual column names differ.
CAPTION_FIELDS = [
    "caption_llm_6",
    "caption_llm_8",
    "caption_cogvlm",
]

# Every column kept in the trimmed output: the categorized tags first,
# followed by the caption columns in the order listed above.
KEEP_FIELDS = ["tags_ground_truth_categorized", *CAPTION_FIELDS]
def pick_id(row: dict) -> str:
    """Return a string identifier for a dataset row.

    The keys ``id``, ``post_id``, ``e621_id`` and ``image_id`` are tried in
    that order; the first one whose value is neither ``None`` nor ``""`` is
    returned, converted with ``str()``.

    When none of them is usable, the id is derived from the first non-empty
    caption among ``caption_llm_6``, ``caption_llm_8`` and ``caption_cogvlm``
    (or the empty string if all are missing) and has the form
    ``"no_id:<16 hex chars>"``.

    Args:
        row: Mapping of column name to value for one dataset row.

    Returns:
        The row identifier as a string.
    """
    # Function-scope import keeps this block self-contained.
    import hashlib

    for key in ("id", "post_id", "e621_id", "image_id"):
        value = row.get(key)
        if value not in (None, ""):
            return str(value)

    base = (
        row.get("caption_llm_6")
        or row.get("caption_llm_8")
        or row.get("caption_cogvlm")
        or ""
    )
    # The built-in hash() of a str is salted per interpreter process, so it
    # would give a different id on every run. A content digest is
    # reproducible across runs and machines.
    digest = hashlib.sha256(str(base).encode("utf-8")).hexdigest()[:16]
    return f"no_id:{digest}"
def _trimmed_rows(rows, require_any_caption: bool):
    """Yield one trimmed output dict per row, optionally skipping caption-less rows."""
    for row in rows:
        out = {"row_id": pick_id(row)}
        for field in KEEP_FIELDS:
            out[field] = row.get(field, "")
        if require_any_caption and not any(
            (out.get(c) or "").strip() for c in CAPTION_FIELDS
        ):
            continue
        yield out


def main() -> None:
    """Sample rows from the streamed, shuffled dataset into a trimmed JSONL file.

    Command-line options:
        --n: maximum number of rows to write (default 1000); must be >= 0.
        --seed: seed passed to the dataset shuffle (default 123).
        --buffer-size: shuffle buffer size (default 10_000).
        --out: output path; parent directories are created if missing.
        --require-any-caption: only keep rows where at least one caption
            field is non-empty after stripping whitespace.

    Each output line is a JSON object holding ``row_id`` plus the columns in
    ``KEEP_FIELDS``. A summary line is printed when writing finishes.
    """
    # Function-scope import keeps this block self-contained.
    from itertools import islice

    ap = argparse.ArgumentParser(description="Stream+shuffle sample and save a trimmed JSONL for prompt experiments.")
    ap.add_argument("--n", type=int, default=1000)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--buffer-size", type=int, default=10_000)
    ap.add_argument(
        "--out",
        type=str,
        default="data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_trimmed.jsonl",
    )
    ap.add_argument(
        "--require-any-caption",
        action="store_true",
        help="If set, only keep rows where at least one of the caption fields is non-empty.",
    )
    args = ap.parse_args()

    # Validate before touching the filesystem or the dataset; a negative
    # count has no sensible meaning for a sample size.
    if args.n < 0:
        ap.error("--n must be a non-negative integer")

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    ds = load_dataset(DATASET_ID, split=SPLIT, streaming=True)
    ds = ds.shuffle(seed=args.seed, buffer_size=args.buffer_size)

    wrote = 0
    with out_path.open("w", encoding="utf-8") as f:
        # islice enforces the limit before a row is pulled, so --n 0 writes
        # nothing and no row beyond the n-th kept one is requested.
        for out in islice(_trimmed_rows(ds, args.require_any_caption), args.n):
            f.write(json.dumps(out, ensure_ascii=False) + "\n")
            wrote += 1

    print(f"Wrote {wrote} rows to: {out_path}")
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|