"""Preprocess eval dataset: expand ground-truth tags through implication chains. Reads the raw eval JSONL, expands each sample's GT tags via the e621 tag implication graph, removes known garbage tags, and writes a new JSONL with an additional `tags_ground_truth_expanded` field (flat sorted list). The original `tags_ground_truth_categorized` field is preserved unchanged. Usage: python scripts/preprocess_eval_data.py Input: data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000.jsonl Output: data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_expanded.jsonl """ from __future__ import annotations import json import sys from pathlib import Path # Add project root to path so we can import psq_rag _REPO_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(_REPO_ROOT)) from psq_rag.retrieval.state import expand_tags_via_implications, get_tag_implications # Tags that are annotation artifacts, not real content tags GARBAGE_TAGS = frozenset({ "invalid_tag", "invalid_background", }) INPUT_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl" OUTPUT_PATH = INPUT_PATH.with_name(INPUT_PATH.stem + "_expanded.jsonl") def flatten_ground_truth(tags_categorized_str: str) -> set[str]: """Parse the categorized ground-truth JSON into a flat set of tags.""" if not tags_categorized_str: return set() cats = json.loads(tags_categorized_str) tags = set() for tag_list in cats.values(): if isinstance(tag_list, list): for t in tag_list: tags.add(t.strip()) return tags def main() -> int: if not INPUT_PATH.is_file(): print(f"ERROR: Input not found: {INPUT_PATH}") return 1 # Pre-warm implication graph impl = get_tag_implications() print(f"Loaded {sum(len(v) for v in impl.values())} active implications") samples_read = 0 samples_expanded = 0 total_tags_added = 0 total_garbage_removed = 0 with INPUT_PATH.open("r", encoding="utf-8") as fin, \ OUTPUT_PATH.open("w", encoding="utf-8") as fout: for line in fin: row = json.loads(line) samples_read += 1 gt_raw = flatten_ground_truth(row.get("tags_ground_truth_categorized", "")) # Remove garbage tags garbage_found = gt_raw & GARBAGE_TAGS if garbage_found: total_garbage_removed += len(garbage_found) gt_raw -= garbage_found # Expand through implications gt_expanded, implied_only = expand_tags_via_implications(gt_raw) if implied_only: samples_expanded += 1 total_tags_added += len(implied_only) # Store expanded flat list alongside original categorized field row["tags_ground_truth_expanded"] = sorted(gt_expanded) fout.write(json.dumps(row, ensure_ascii=False) + "\n") print(f"Processed {samples_read} samples") print(f" {samples_expanded} samples had missing implications ({samples_expanded}/{samples_read} = {100*samples_expanded/samples_read:.1f}%)") print(f" {total_tags_added} implied tags added total (avg {total_tags_added/samples_read:.1f} per sample)") print(f" {total_garbage_removed} garbage tags removed") print(f"Output: {OUTPUT_PATH}") return 0 if __name__ == "__main__": sys.exit(main())