Claude committed on
Commit
14e5c38
·
1 Parent(s): 6fc4b56

Normalize GT annotations: expand implications, exclude non-evaluable tags

Browse files

Addresses annotation inconsistency where 30% of GT samples were missing
implied taxonomy tags (e.g. fox present but canid/mammal absent).

- preprocess_eval_data.py: expands GT through implication graph, writes
_expanded.jsonl with tags_ground_truth_expanded field
- eval_pipeline.py: uses expanded GT, strips _EVAL_EXCLUDED_TAGS
(invalid_*, hi_res, structural backgrounds) from both sides,
reports leaf-only metrics alongside expanded metrics
- state.py: adds get_leaf_tags() to strip implied ancestors from a tag set

https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG

data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_expanded.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
psq_rag/retrieval/state.py CHANGED
@@ -327,6 +327,29 @@ def expand_tags_via_implications(tags: Set[str]) -> Tuple[Set[str], Set[str]]:
327
  return expanded, implied_only
328
 
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  def get_tfidf_tag_vectors() -> Dict[str, Any]:
331
  global _tfidf_tag_vectors
332
  if _tfidf_tag_vectors is not None:
 
327
  return expanded, implied_only
328
 
329
 
330
def get_leaf_tags(tags: Set[str]) -> Set[str]:
    """Return only leaf tags — those not implied by any other tag in the set.

    For example, given {fox, canine, canid, mammal}, returns {fox} because
    canine/canid/mammal are all reachable from fox via implications.
    """
    implications = get_tag_implications()
    # Every tag reachable from some member of `tags` that is itself in `tags`
    # is an implied ancestor and therefore not a leaf.
    implied_ancestors: Set[str] = set()
    for start in tags:
        # Depth-first walk of everything `start` implies (transitively).
        seen: Set[str] = set()
        stack = [start]
        while stack:
            current = stack.pop()
            for ancestor in implications.get(current, ()):
                if ancestor in seen:
                    continue
                seen.add(ancestor)
                if ancestor in tags:
                    implied_ancestors.add(ancestor)
                stack.append(ancestor)
    return tags - implied_ancestors
351
+
352
+
353
  def get_tfidf_tag_vectors() -> Dict[str, Any]:
354
  global _tfidf_tag_vectors
355
  if _tfidf_tag_vectors is not None:
scripts/eval_pipeline.py CHANGED
@@ -57,13 +57,29 @@ if str(_REPO_ROOT) not in sys.path:
57
  sys.path.insert(0, str(_REPO_ROOT))
58
  os.chdir(_REPO_ROOT)
59
 
60
- EVAL_DATA_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
 
61
 
62
  # Character tag types that go through the alias filter pipeline
63
  _CHARACTER_TYPES = {"character"}
64
  # Copyright tags are filtered out entirely
65
  _COPYRIGHT_TYPES = {"copyright"}
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  def _classify_tags(tags: Set[str], get_type_fn) -> Tuple[Set[str], Set[str]]:
69
  """Split tags into (character_tags, general_tags).
@@ -135,6 +151,12 @@ class SampleResult:
135
  why_counts: Dict[str, int] = field(default_factory=dict)
136
  # Tag implications
137
  implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
 
 
 
 
 
 
138
  # Timing
139
  stage1_time: float = 0.0
140
  stage2_time: float = 0.0
@@ -179,7 +201,7 @@ def _process_one_sample(
179
  from psq_rag.llm.rewrite import llm_rewrite_prompt
180
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
181
  from psq_rag.llm.select import llm_select_indices
182
- from psq_rag.retrieval.state import get_tag_type_name, expand_tags_via_implications
183
 
184
  def log(msg: str) -> None:
185
  if verbose:
@@ -273,13 +295,27 @@ def _process_one_sample(
273
  result.selected_tags = expanded
274
  log(f"Implications: +{len(implied_only)} tags")
275
 
276
- # Overall selection metrics
 
 
 
 
277
  p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
278
  result.selection_precision = p
279
  result.selection_recall = r
280
  result.selection_f1 = f1
281
 
282
- # New diagnostic metrics
 
 
 
 
 
 
 
 
 
 
283
  retrieved_and_gt = result.retrieved_tags & gt_tags
284
  selected_and_gt = result.selected_tags & gt_tags
285
  if result.retrieved_tags:
@@ -370,26 +406,41 @@ def run_eval(
370
  expand_implications: bool = False,
371
  ) -> List[SampleResult]:
372
 
373
- # Load eval samples
374
- if not EVAL_DATA_PATH.is_file():
375
- print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
376
- sys.exit(1)
 
 
 
 
 
377
 
378
  all_samples = []
379
- with EVAL_DATA_PATH.open("r", encoding="utf-8") as f:
 
380
  for line in f:
381
  row = json.loads(line)
382
  caption = row.get(caption_field, "")
383
  if not caption or not caption.strip():
384
  continue
385
- gt_tags = _flatten_ground_truth_tags(row.get("tags_ground_truth_categorized", ""))
 
 
 
 
 
386
  if not gt_tags:
387
  continue
 
 
388
  all_samples.append({
389
  "id": row.get("id", row.get("row_id", len(all_samples))),
390
  "caption": caption.strip(),
391
  "gt_tags": gt_tags,
392
  })
 
 
393
 
394
  if shuffle:
395
  rng = random.Random(seed)
@@ -512,6 +563,21 @@ def print_summary(results: List[SampleResult]) -> None:
512
  if avg_implied > 0:
513
  print(f" Avg implied tags: {avg_implied:.1f} (added via tag implications)")
514
  print(f" Avg ground-truth tags:{avg_gt:.1f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  print()
516
  print("Diagnostic Metrics:")
517
  print(f" Retrieval precision: {avg_retrieval_precision:.4f} (|ret∩gt|/|ret|, noise level fed to Stage 3)")
@@ -761,6 +827,12 @@ def main(argv=None) -> int:
761
  "over_selection_ratio": round(r.over_selection_ratio, 2),
762
  "why_counts": r.why_counts,
763
  "implied_tags": sorted(r.implied_tags),
 
 
 
 
 
 
764
  # Timing
765
  "stage1_time": round(r.stage1_time, 3),
766
  "stage2_time": round(r.stage2_time, 3),
 
57
  sys.path.insert(0, str(_REPO_ROOT))
58
  os.chdir(_REPO_ROOT)
59
 
60
+ EVAL_DATA_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000_expanded.jsonl"
61
+ EVAL_DATA_PATH_RAW = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
62
 
63
  # Character tag types that go through the alias filter pipeline
64
  _CHARACTER_TYPES = {"character"}
65
  # Copyright tags are filtered out entirely
66
  _COPYRIGHT_TYPES = {"copyright"}
67
 
68
+ # Tags excluded from evaluation metrics but NOT removed from the pipeline.
69
+ # These are tags that either: can't be inferred from a caption (resolution,
70
+ # art medium), describe structural properties better handled outside the
71
+ # retrieval pipeline (backgrounds), or are annotation artifacts.
72
+ _EVAL_EXCLUDED_TAGS = frozenset({
73
+ # Annotation artifacts
74
+ "invalid_tag", "invalid_background",
75
+ # Resolution / file meta — not inferrable from caption
76
+ "hi_res", "absurd_res", "low_res", "superabsurd_res",
77
+ # Structural background tags — better recommended independently
78
+ "simple_background", "abstract_background", "detailed_background",
79
+ "gradient_background", "blurred_background", "textured_background",
80
+ "transparent_background", "white_background",
81
+ })
82
+
83
 
84
  def _classify_tags(tags: Set[str], get_type_fn) -> Tuple[Set[str], Set[str]]:
85
  """Split tags into (character_tags, general_tags).
 
151
  why_counts: Dict[str, int] = field(default_factory=dict)
152
  # Tag implications
153
  implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
154
+ # Leaf-only metrics (strips implied ancestors from both sides)
155
+ leaf_precision: float = 0.0
156
+ leaf_recall: float = 0.0
157
+ leaf_f1: float = 0.0
158
+ leaf_selected_count: int = 0
159
+ leaf_gt_count: int = 0
160
  # Timing
161
  stage1_time: float = 0.0
162
  stage2_time: float = 0.0
 
201
  from psq_rag.llm.rewrite import llm_rewrite_prompt
202
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
203
  from psq_rag.llm.select import llm_select_indices
204
+ from psq_rag.retrieval.state import get_tag_type_name, expand_tags_via_implications, get_leaf_tags
205
 
206
  def log(msg: str) -> None:
207
  if verbose:
 
295
  result.selected_tags = expanded
296
  log(f"Implications: +{len(implied_only)} tags")
297
 
298
+ # Remove eval-excluded tags from predictions before scoring
299
+ result.selected_tags -= _EVAL_EXCLUDED_TAGS
300
+ result.retrieved_tags -= _EVAL_EXCLUDED_TAGS
301
+
302
+ # Overall selection metrics (expanded — both sides have full implication chains)
303
  p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
304
  result.selection_precision = p
305
  result.selection_recall = r
306
  result.selection_f1 = f1
307
 
308
+ # Leaf-only metrics (strips implied ancestors from both sides)
309
+ leaf_sel = get_leaf_tags(result.selected_tags)
310
+ leaf_gt = get_leaf_tags(gt_tags)
311
+ lp, lr, lf1 = _compute_metrics(leaf_sel, leaf_gt)
312
+ result.leaf_precision = lp
313
+ result.leaf_recall = lr
314
+ result.leaf_f1 = lf1
315
+ result.leaf_selected_count = len(leaf_sel)
316
+ result.leaf_gt_count = len(leaf_gt)
317
+
318
+ # Diagnostic metrics
319
  retrieved_and_gt = result.retrieved_tags & gt_tags
320
  selected_and_gt = result.selected_tags & gt_tags
321
  if result.retrieved_tags:
 
406
  expand_implications: bool = False,
407
  ) -> List[SampleResult]:
408
 
409
+ # Load eval samples — prefer expanded file, fall back to raw
410
+ eval_path = EVAL_DATA_PATH
411
+ if not eval_path.is_file():
412
+ eval_path = EVAL_DATA_PATH_RAW
413
+ if not eval_path.is_file():
414
+ print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
415
+ sys.exit(1)
416
+ print(f"WARNING: Expanded eval data not found, falling back to raw: {eval_path}")
417
+ print(" Run: python scripts/preprocess_eval_data.py")
418
 
419
  all_samples = []
420
+ using_expanded = False
421
+ with eval_path.open("r", encoding="utf-8") as f:
422
  for line in f:
423
  row = json.loads(line)
424
  caption = row.get(caption_field, "")
425
  if not caption or not caption.strip():
426
  continue
427
+ # Prefer pre-expanded GT; fall back to flattening categorized
428
+ if "tags_ground_truth_expanded" in row:
429
+ gt_tags = set(row["tags_ground_truth_expanded"])
430
+ using_expanded = True
431
+ else:
432
+ gt_tags = _flatten_ground_truth_tags(row.get("tags_ground_truth_categorized", ""))
433
  if not gt_tags:
434
  continue
435
+ # Remove eval-excluded tags from GT
436
+ gt_tags -= _EVAL_EXCLUDED_TAGS
437
  all_samples.append({
438
  "id": row.get("id", row.get("row_id", len(all_samples))),
439
  "caption": caption.strip(),
440
  "gt_tags": gt_tags,
441
  })
442
+ if using_expanded:
443
+ print("Using implication-expanded ground truth")
444
 
445
  if shuffle:
446
  rng = random.Random(seed)
 
563
  if avg_implied > 0:
564
  print(f" Avg implied tags: {avg_implied:.1f} (added via tag implications)")
565
  print(f" Avg ground-truth tags:{avg_gt:.1f}")
566
+
567
+ # Leaf-only metrics
568
+ avg_leaf_p = _safe_avg([r.leaf_precision for r in valid])
569
+ avg_leaf_r = _safe_avg([r.leaf_recall for r in valid])
570
+ avg_leaf_f1 = _safe_avg([r.leaf_f1 for r in valid])
571
+ avg_leaf_sel = _safe_avg([r.leaf_selected_count for r in valid])
572
+ avg_leaf_gt = _safe_avg([r.leaf_gt_count for r in valid])
573
+ print()
574
+ print("Stage 3 - Selection (LEAF tags only — implied ancestors stripped):")
575
+ print(f" Avg precision: {avg_leaf_p:.4f}")
576
+ print(f" Avg recall: {avg_leaf_r:.4f}")
577
+ print(f" Avg F1: {avg_leaf_f1:.4f}")
578
+ print(f" Avg leaf selected: {avg_leaf_sel:.1f}")
579
+ print(f" Avg leaf ground-truth:{avg_leaf_gt:.1f}")
580
+
581
  print()
582
  print("Diagnostic Metrics:")
583
  print(f" Retrieval precision: {avg_retrieval_precision:.4f} (|ret∩gt|/|ret|, noise level fed to Stage 3)")
 
827
  "over_selection_ratio": round(r.over_selection_ratio, 2),
828
  "why_counts": r.why_counts,
829
  "implied_tags": sorted(r.implied_tags),
830
+ # Leaf metrics
831
+ "leaf_precision": round(r.leaf_precision, 4),
832
+ "leaf_recall": round(r.leaf_recall, 4),
833
+ "leaf_f1": round(r.leaf_f1, 4),
834
+ "leaf_selected_count": r.leaf_selected_count,
835
+ "leaf_gt_count": r.leaf_gt_count,
836
  # Timing
837
  "stage1_time": round(r.stage1_time, 3),
838
  "stage2_time": round(r.stage2_time, 3),
scripts/preprocess_eval_data.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Preprocess eval dataset: expand ground-truth tags through implication chains.
2
+
3
+ Reads the raw eval JSONL, expands each sample's GT tags via the e621 tag
4
+ implication graph, removes known garbage tags, and writes a new JSONL with
5
+ an additional `tags_ground_truth_expanded` field (flat sorted list).
6
+
7
+ The original `tags_ground_truth_categorized` field is preserved unchanged.
8
+
9
+ Usage:
10
+ python scripts/preprocess_eval_data.py
11
+
12
+ Input: data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000.jsonl
13
+ Output: data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_expanded.jsonl
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import sys
20
+ from pathlib import Path
21
+
22
+ # Add project root to path so we can import psq_rag
23
+ _REPO_ROOT = Path(__file__).resolve().parent.parent
24
+ sys.path.insert(0, str(_REPO_ROOT))
25
+
26
+ from psq_rag.retrieval.state import expand_tags_via_implications, get_tag_implications
27
+
28
+ # Tags that are annotation artifacts, not real content tags
29
+ GARBAGE_TAGS = frozenset({
30
+ "invalid_tag",
31
+ "invalid_background",
32
+ })
33
+
34
+ INPUT_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
35
+ OUTPUT_PATH = INPUT_PATH.with_name(INPUT_PATH.stem + "_expanded.jsonl")
36
+
37
+
38
def flatten_ground_truth(tags_categorized_str: str) -> set[str]:
    """Parse the categorized ground-truth JSON into a flat set of tags."""
    if not tags_categorized_str:
        return set()
    categories = json.loads(tags_categorized_str)
    # Flatten every list-valued category; whitespace-trim each tag.
    # Non-list values (unexpected shapes) are skipped, as in the original.
    return {
        tag.strip()
        for values in categories.values()
        if isinstance(values, list)
        for tag in values
    }
49
+
50
+
51
def main() -> int:
    """Expand each eval sample's GT tags via the implication graph.

    Reads INPUT_PATH line-by-line (JSONL), strips GARBAGE_TAGS, expands the
    remaining tags with expand_tags_via_implications, and writes each row to
    OUTPUT_PATH with an added `tags_ground_truth_expanded` field (sorted list).
    The original `tags_ground_truth_categorized` field is preserved unchanged.

    Returns:
        Process exit code: 0 on success, 1 if the input file is missing.
    """
    if not INPUT_PATH.is_file():
        print(f"ERROR: Input not found: {INPUT_PATH}")
        return 1

    # Pre-warm the implication graph so load time isn't attributed to row 1.
    impl = get_tag_implications()
    print(f"Loaded {sum(len(v) for v in impl.values())} active implications")

    samples_read = 0
    samples_expanded = 0
    total_tags_added = 0
    total_garbage_removed = 0

    with INPUT_PATH.open("r", encoding="utf-8") as fin, \
         OUTPUT_PATH.open("w", encoding="utf-8") as fout:
        for line in fin:
            row = json.loads(line)
            samples_read += 1

            gt_raw = flatten_ground_truth(row.get("tags_ground_truth_categorized", ""))

            # Remove annotation-artifact tags before expansion
            garbage_found = gt_raw & GARBAGE_TAGS
            if garbage_found:
                total_garbage_removed += len(garbage_found)
                gt_raw -= garbage_found

            # Expand through implications
            gt_expanded, implied_only = expand_tags_via_implications(gt_raw)
            if implied_only:
                samples_expanded += 1
                total_tags_added += len(implied_only)

            # Store expanded flat list alongside original categorized field
            row["tags_ground_truth_expanded"] = sorted(gt_expanded)

            fout.write(json.dumps(row, ensure_ascii=False) + "\n")

    print(f"Processed {samples_read} samples")
    # Guard: an empty input file previously crashed with ZeroDivisionError here.
    if samples_read:
        print(f"  {samples_expanded} samples had missing implications ({samples_expanded}/{samples_read} = {100*samples_expanded/samples_read:.1f}%)")
        print(f"  {total_tags_added} implied tags added total (avg {total_tags_added/samples_read:.1f} per sample)")
    print(f"  {total_garbage_removed} garbage tags removed")
    print(f"Output: {OUTPUT_PATH}")
    return 0
96
+
97
+
98
if __name__ == "__main__":
    # SystemExit with main()'s return code — equivalent to sys.exit(main()).
    raise SystemExit(main())