File size: 3,530 Bytes
14e5c38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Preprocess eval dataset: expand ground-truth tags through implication chains.

Reads the raw eval JSONL, expands each sample's GT tags via the e621 tag
implication graph, removes known garbage tags, and writes a new JSONL with
an additional `tags_ground_truth_expanded` field (flat sorted list).

The original `tags_ground_truth_categorized` field is preserved unchanged.

Usage:
    python scripts/preprocess_eval_data.py

Input:  data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000.jsonl
Output: data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_expanded.jsonl
"""

from __future__ import annotations

import json
import sys
from pathlib import Path

# Add project root to path so we can import psq_rag
_REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(_REPO_ROOT))

from psq_rag.retrieval.state import expand_tags_via_implications, get_tag_implications

# Tags that are annotation artifacts, not real content tags.
# These are stripped from the ground truth BEFORE implication expansion
# so they never contribute implied tags.
GARBAGE_TAGS = frozenset({
    "invalid_tag",
    "invalid_background",
})

# Raw eval sample file produced upstream; one JSON object per line.
INPUT_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
# Output sits next to the input, with "_expanded" appended to the stem.
OUTPUT_PATH = INPUT_PATH.with_name(INPUT_PATH.stem + "_expanded.jsonl")


def flatten_ground_truth(tags_categorized_str: str) -> set[str]:
    """Parse the categorized ground-truth JSON into a flat set of tags.

    Args:
        tags_categorized_str: JSON object mapping category name -> list of
            tag strings (e.g. ``{"general": ["a"], "artist": ["b"]}``).
            May be empty or missing entirely.

    Returns:
        Flat set of stripped tag names across all categories. Non-list
        category values are ignored, and tags that are empty (or whitespace
        only) after stripping are dropped rather than added as ``""``.

    Raises:
        json.JSONDecodeError: If the non-empty input is not valid JSON.
    """
    if not tags_categorized_str:
        return set()
    cats = json.loads(tags_categorized_str)
    return {
        stripped
        for tag_list in cats.values()
        if isinstance(tag_list, list)  # defensively skip malformed categories
        for t in tag_list
        if (stripped := t.strip())  # fix: drop empty/whitespace-only tags
    }


def main() -> int:
    """Expand each sample's GT tags via implications and write new JSONL.

    Reads INPUT_PATH line by line, removes GARBAGE_TAGS, expands the
    remaining tags through the e621 implication graph, and writes each row
    (with an added `tags_ground_truth_expanded` list) to OUTPUT_PATH.

    Returns:
        0 on success, 1 if the input file is missing.
    """
    if not INPUT_PATH.is_file():
        print(f"ERROR: Input not found: {INPUT_PATH}")
        return 1

    # Pre-warm implication graph (load it once, up front)
    impl = get_tag_implications()
    print(f"Loaded {sum(len(v) for v in impl.values())} active implications")

    samples_read = 0
    samples_expanded = 0       # samples that gained at least one implied tag
    total_tags_added = 0       # implied tags added across all samples
    total_garbage_removed = 0  # garbage tags stripped across all samples

    with INPUT_PATH.open("r", encoding="utf-8") as fin, \
         OUTPUT_PATH.open("w", encoding="utf-8") as fout:
        for line in fin:
            row = json.loads(line)
            samples_read += 1

            gt_raw = flatten_ground_truth(row.get("tags_ground_truth_categorized", ""))

            # Remove garbage tags BEFORE expansion so annotation artifacts
            # never contribute implied tags.
            garbage_found = gt_raw & GARBAGE_TAGS
            if garbage_found:
                total_garbage_removed += len(garbage_found)
                gt_raw -= garbage_found

            # Expand through implications
            gt_expanded, implied_only = expand_tags_via_implications(gt_raw)
            if implied_only:
                samples_expanded += 1
                total_tags_added += len(implied_only)

            # Store expanded flat list alongside original categorized field
            row["tags_ground_truth_expanded"] = sorted(gt_expanded)

            fout.write(json.dumps(row, ensure_ascii=False) + "\n")

    print(f"Processed {samples_read} samples")
    # Fix: guard the per-sample ratios — an empty input file previously
    # raised ZeroDivisionError here.
    if samples_read:
        print(f"  {samples_expanded} samples had missing implications ({samples_expanded}/{samples_read} = {100*samples_expanded/samples_read:.1f}%)")
        print(f"  {total_tags_added} implied tags added total (avg {total_tags_added/samples_read:.1f} per sample)")
    print(f"  {total_garbage_removed} garbage tags removed")
    print(f"Output: {OUTPUT_PATH}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status code (0 = success, 1 = missing input) to the shell.
    sys.exit(main())