"""Preprocess eval dataset: expand ground-truth tags through implication chains.
Reads the raw eval JSONL, expands each sample's GT tags via the e621 tag
implication graph, removes known garbage tags, and writes a new JSONL with
an additional `tags_ground_truth_expanded` field (flat sorted list).
The original `tags_ground_truth_categorized` field is preserved unchanged.
Usage:
python scripts/preprocess_eval_data.py
Input: data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000.jsonl
Output: data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_expanded.jsonl
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
# Add project root to path so we can import psq_rag when run as a script
# (scripts/ is not a package, so relative imports would not work).
_REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_REPO_ROOT))
from psq_rag.retrieval.state import expand_tags_via_implications, get_tag_implications

# Tags that are annotation artifacts, not real content tags; they are
# stripped from the ground truth before implication expansion.
GARBAGE_TAGS = frozenset({
    "invalid_tag",
    "invalid_background",
})

# Raw eval samples (one JSON object per line).
INPUT_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
# Output sits next to the input with an "_expanded" suffix on the stem.
OUTPUT_PATH = INPUT_PATH.with_name(INPUT_PATH.stem + "_expanded.jsonl")
def flatten_ground_truth(tags_categorized_str: str) -> set[str]:
    """Parse the categorized ground-truth JSON into a flat set of tags.

    Args:
        tags_categorized_str: JSON object string mapping category names to
            lists of tag strings, e.g. ``'{"general": ["tag_a", "tag_b"]}'``.
            Non-list category values are ignored.

    Returns:
        The union of all tags across categories, whitespace-stripped.
        Tags that strip to the empty string are discarded (previously they
        leaked an empty-string entry into the set).
    """
    if not tags_categorized_str:
        return set()
    cats = json.loads(tags_categorized_str)
    return {
        stripped
        for tag_list in cats.values()
        if isinstance(tag_list, list)
        for t in tag_list
        if (stripped := t.strip())  # drop whitespace-only "tags"
    }
def main() -> int:
    """Expand ground-truth tags for every eval sample and write a new JSONL.

    Reads INPUT_PATH line by line, strips garbage tags, expands the rest
    through the e621 implication graph, and writes each row (with the added
    `tags_ground_truth_expanded` field) to OUTPUT_PATH.

    Returns:
        0 on success, 1 if the input file is missing.
    """
    if not INPUT_PATH.is_file():
        print(f"ERROR: Input not found: {INPUT_PATH}")
        return 1
    # Pre-warm implication graph
    impl = get_tag_implications()
    print(f"Loaded {sum(len(v) for v in impl.values())} active implications")
    samples_read = 0
    samples_expanded = 0
    total_tags_added = 0
    total_garbage_removed = 0
    with INPUT_PATH.open("r", encoding="utf-8") as fin, \
         OUTPUT_PATH.open("w", encoding="utf-8") as fout:
        for line in fin:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # crashing in json.loads.
                continue
            row = json.loads(line)
            samples_read += 1
            gt_raw = flatten_ground_truth(row.get("tags_ground_truth_categorized", ""))
            # Remove garbage tags
            garbage_found = gt_raw & GARBAGE_TAGS
            if garbage_found:
                total_garbage_removed += len(garbage_found)
                gt_raw -= garbage_found
            # Expand through implications
            gt_expanded, implied_only = expand_tags_via_implications(gt_raw)
            if implied_only:
                samples_expanded += 1
                total_tags_added += len(implied_only)
            # Store expanded flat list alongside original categorized field
            row["tags_ground_truth_expanded"] = sorted(gt_expanded)
            fout.write(json.dumps(row, ensure_ascii=False) + "\n")
    print(f"Processed {samples_read} samples")
    if samples_read:
        # Ratio lines only make sense (and only divide safely) with >= 1 sample.
        print(f"  {samples_expanded} samples had missing implications ({samples_expanded}/{samples_read} = {100*samples_expanded/samples_read:.1f}%)")
        print(f"  {total_tags_added} implied tags added total (avg {total_tags_added/samples_read:.1f} per sample)")
    print(f"  {total_garbage_removed} garbage tags removed")
    print(f"Output: {OUTPUT_PATH}")
    return 0
# Script entry point: propagate main()'s status code to the shell.
if __name__ == "__main__":
    sys.exit(main())