#!/usr/bin/env python3
| """Preprocess eval dataset: expand ground-truth tags through implication chains. | |
| Reads the raw eval JSONL, expands each sample's GT tags via the e621 tag | |
| implication graph, removes known garbage tags, and writes a new JSONL with | |
| an additional `tags_ground_truth_expanded` field (flat sorted list). | |
| The original `tags_ground_truth_categorized` field is preserved unchanged. | |
| Usage: | |
| python scripts/preprocess_eval_data.py | |
| Input: data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000.jsonl | |
| Output: data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_expanded.jsonl | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import sys | |
| from pathlib import Path | |
| # Add project root to path so we can import psq_rag | |
| _REPO_ROOT = Path(__file__).resolve().parent.parent | |
| sys.path.insert(0, str(_REPO_ROOT)) | |
| from psq_rag.retrieval.state import expand_tags_via_implications, get_tag_implications | |
# Annotation artifacts occasionally present in the raw data; these are not
# real content tags and must be stripped before implication expansion.
GARBAGE_TAGS = frozenset(["invalid_tag", "invalid_background"])

# Raw eval sample file (input) and its "_expanded" sibling (output).
INPUT_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
OUTPUT_PATH = INPUT_PATH.with_name(INPUT_PATH.stem + "_expanded.jsonl")
def flatten_ground_truth(tags_categorized_str: str) -> set[str]:
    """Parse the categorized ground-truth JSON into a flat set of tags.

    The input is a JSON object mapping category names (e.g. "general",
    "artist") to lists of tag strings. Values that are not lists are
    ignored, as are entries that are not strings or are whitespace-only
    (the original implementation crashed on non-string entries and kept
    "" for whitespace-only tags).

    Args:
        tags_categorized_str: JSON-encoded category->tags mapping, or ""
            when the sample has no ground truth.

    Returns:
        Flat set of stripped, non-empty tag strings.
    """
    if not tags_categorized_str:
        return set()
    cats = json.loads(tags_categorized_str)
    return {
        tag.strip()
        for tag_list in cats.values()
        if isinstance(tag_list, list)
        for tag in tag_list
        # Skip non-string entries and tags that are empty after stripping.
        if isinstance(tag, str) and tag.strip()
    }
def main() -> int:
    """Expand ground-truth tags for every eval sample and write the result.

    Reads INPUT_PATH line by line, strips garbage tags, expands each
    sample's flat tag set through the implication graph, and writes the
    augmented rows to OUTPUT_PATH with a new `tags_ground_truth_expanded`
    field (sorted list). The original categorized field is preserved.

    Returns:
        0 on success, 1 if the input file does not exist.
    """
    if not INPUT_PATH.is_file():
        print(f"ERROR: Input not found: {INPUT_PATH}")
        return 1
    # Pre-warm implication graph
    impl = get_tag_implications()
    print(f"Loaded {sum(len(v) for v in impl.values())} active implications")
    samples_read = 0
    samples_expanded = 0
    total_tags_added = 0
    total_garbage_removed = 0
    with INPUT_PATH.open("r", encoding="utf-8") as fin, \
        OUTPUT_PATH.open("w", encoding="utf-8") as fout:
        for line in fin:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # crashing in json.loads.
            if not line.strip():
                continue
            row = json.loads(line)
            samples_read += 1
            gt_raw = flatten_ground_truth(row.get("tags_ground_truth_categorized", ""))
            # Remove garbage tags
            garbage_found = gt_raw & GARBAGE_TAGS
            if garbage_found:
                total_garbage_removed += len(garbage_found)
                gt_raw -= garbage_found
            # Expand through implications
            gt_expanded, implied_only = expand_tags_via_implications(gt_raw)
            if implied_only:
                samples_expanded += 1
                total_tags_added += len(implied_only)
            # Store expanded flat list alongside original categorized field
            row["tags_ground_truth_expanded"] = sorted(gt_expanded)
            fout.write(json.dumps(row, ensure_ascii=False) + "\n")
    print(f"Processed {samples_read} samples")
    # Guard the per-sample averages: an empty input file previously raised
    # ZeroDivisionError here.
    if samples_read:
        print(f"  {samples_expanded} samples had missing implications ({samples_expanded}/{samples_read} = {100*samples_expanded/samples_read:.1f}%)")
        print(f"  {total_tags_added} implied tags added total (avg {total_tags_added/samples_read:.1f} per sample)")
    print(f"  {total_garbage_removed} garbage tags removed")
    print(f"Output: {OUTPUT_PATH}")
    return 0
| if __name__ == "__main__": | |
| sys.exit(main()) | |