Claude committed
Commit f1b4da2 · Parent: 133d74c

Add independent character tag metrics to eval pipeline


Character and general tags now tracked and reported separately:
- Character: retrieval recall, selection P/R/F1, missed/false-positive lists
- General: selection P/R/F1 (non-character, non-copyright)
- Detects spurious character selections (selected character with none in GT)
- Per-sample output shows character breakdown inline
- JSONL output includes all per-type fields for analysis

https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG
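
For reference, the diff below unpacks a `(p, r, f1)` tuple from an existing `_compute_metrics` helper that is not part of this change. A minimal sketch of what a set-based precision/recall/F1 helper of that shape typically looks like (hypothetical; the real implementation lives elsewhere in scripts/eval_pipeline.py):

    # Hypothetical sketch, not the repo's code: set-based P/R/F1 matching
    # the (p, r, f1) tuple the diff unpacks from _compute_metrics.
    from typing import Set, Tuple

    def _compute_metrics(predicted: Set[str], ground_truth: Set[str]) -> Tuple[float, float, float]:
        tp = len(predicted & ground_truth)  # true positives: tags in both sets
        precision = tp / len(predicted) if predicted else 0.0
        recall = tp / len(ground_truth) if ground_truth else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        return precision, recall, f1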

Files changed (1)
  1. scripts/eval_pipeline.py +169 -9
scripts/eval_pipeline.py CHANGED

@@ -51,6 +51,28 @@ os.chdir(_REPO_ROOT)
 
 EVAL_DATA_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
 
+# Character tag types that go through the alias filter pipeline
+_CHARACTER_TYPES = {"character"}
+# Copyright tags are filtered out entirely
+_COPYRIGHT_TYPES = {"copyright"}
+
+
+def _classify_tags(tags: Set[str], get_type_fn) -> Tuple[Set[str], Set[str]]:
+    """Split tags into (character_tags, general_tags).
+
+    Copyright tags are excluded from both sets since they're filtered
+    before any selection happens.
+    """
+    character = set()
+    general = set()
+    for tag in tags:
+        ttype = get_type_fn(tag)
+        if ttype in _CHARACTER_TYPES:
+            character.add(tag)
+        elif ttype not in _COPYRIGHT_TYPES:
+            general.add(tag)
+    return character, general
+
 
 def _flatten_ground_truth_tags(tags_categorized_str: str) -> Set[str]:
     """Parse the categorized ground-truth JSON string into a flat set of tags."""

@@ -78,11 +100,25 @@ class SampleResult:
     # Stage 2
     retrieved_tags: Set[str] = field(default_factory=set)
     retrieval_recall: float = 0.0
-    # Stage 3
+    # Stage 3 — overall
     selected_tags: Set[str] = field(default_factory=set)
     selection_precision: float = 0.0
     selection_recall: float = 0.0
     selection_f1: float = 0.0
+    # Stage 3 — character tags only
+    gt_character_tags: Set[str] = field(default_factory=set)
+    selected_character_tags: Set[str] = field(default_factory=set)
+    retrieved_character_tags: Set[str] = field(default_factory=set)
+    char_retrieval_recall: float = 0.0
+    char_precision: float = 0.0
+    char_recall: float = 0.0
+    char_f1: float = 0.0
+    # Stage 3 — general tags only (non-character, non-copyright)
+    gt_general_tags: Set[str] = field(default_factory=set)
+    selected_general_tags: Set[str] = field(default_factory=set)
+    general_precision: float = 0.0
+    general_recall: float = 0.0
+    general_f1: float = 0.0
     # Timing
     stage1_time: float = 0.0
     stage2_time: float = 0.0

@@ -125,6 +161,7 @@ def run_eval(
     from psq_rag.llm.rewrite import llm_rewrite_prompt
     from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
     from psq_rag.llm.select import llm_select_indices
+    from psq_rag.retrieval.state import get_tag_type_name
 
     def log(msg: str) -> None:
         if verbose:

@@ -240,16 +277,45 @@ def run_eval(
 
         result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()
 
-        # Selection metrics
+        # Overall selection metrics
         p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
         result.selection_precision = p
         result.selection_recall = r
         result.selection_f1 = f1
 
+        # Split ground-truth and selected tags by type
+        gt_char, gt_gen = _classify_tags(gt_tags, get_tag_type_name)
+        sel_char, sel_gen = _classify_tags(result.selected_tags, get_tag_type_name)
+        ret_char, _ = _classify_tags(result.retrieved_tags, get_tag_type_name)
+
+        result.gt_character_tags = gt_char
+        result.selected_character_tags = sel_char
+        result.retrieved_character_tags = ret_char
+        result.gt_general_tags = gt_gen
+        result.selected_general_tags = sel_gen
+
+        # Character-specific metrics
+        if gt_char:
+            result.char_retrieval_recall = len(ret_char & gt_char) / len(gt_char)
+            cp, cr, cf1 = _compute_metrics(sel_char, gt_char)
+            result.char_precision = cp
+            result.char_recall = cr
+            result.char_f1 = cf1
+
+        # General-tag metrics
+        gp, gr, gf1 = _compute_metrics(sel_gen, gt_gen)
+        result.general_precision = gp
+        result.general_recall = gr
+        result.general_f1 = gf1
+
+        # Per-sample output line
+        char_info = ""
+        if gt_char:
+            char_info = f" char[gt={len(gt_char)} sel={len(sel_char)} P={cp:.2f} R={cr:.2f}]"
         print(
             f" retrieval_recall={result.retrieval_recall:.3f} "
             f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
-            f"selected={len(result.selected_tags)} "
+            f"selected={len(result.selected_tags)}{char_info} "
             f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
         )
 

@@ -262,6 +328,10 @@ def run_eval(
     return results
 
 
+def _safe_avg(values: List[float]) -> float:
+    return sum(values) / len(values) if values else 0.0
+
+
 def print_summary(results: List[SampleResult]) -> None:
     """Print aggregate metrics across all samples."""
     valid = [r for r in results if r.error is None]

@@ -287,21 +357,96 @@ def print_summary(results: List[SampleResult]) -> None:
     avg_t3 = sum(r.stage3_time for r in valid) / n
 
     print()
-    print("=" * 60)
+    print("=" * 70)
     print(f"EVALUATION SUMMARY ({n} samples, {len(errored)} errors)")
-    print("=" * 60)
+    print("=" * 70)
     print()
     print("Stage 2 - Retrieval:")
     print(f" Avg recall@300: {avg_retrieval_recall:.4f}")
     print(f" Avg candidates: {avg_retrieved:.1f}")
     print()
-    print("Stage 3 - Selection (final output):")
+    print("Stage 3 - Selection (ALL tags):")
     print(f" Avg precision: {avg_sel_precision:.4f}")
     print(f" Avg recall: {avg_sel_recall:.4f}")
     print(f" Avg F1: {avg_sel_f1:.4f}")
     print(f" Avg selected tags: {avg_selected:.1f}")
     print(f" Avg ground-truth tags:{avg_gt:.1f}")
+
+    # --- Character tag breakdown ---
+    # Only include samples that actually have character tags in ground truth
+    samples_with_chars = [r for r in valid if r.gt_character_tags]
+    # Samples where the system selected character tags (true or false positive)
+    samples_selecting_chars = [r for r in valid if r.selected_character_tags]
+
+    print()
+    print("-" * 70)
+    print(f"CHARACTER TAGS ({len(samples_with_chars)}/{n} samples have character ground-truth)")
+    print("-" * 70)
+
+    if samples_with_chars:
+        avg_char_retrieval_recall = _safe_avg([r.char_retrieval_recall for r in samples_with_chars])
+        avg_char_p = _safe_avg([r.char_precision for r in samples_with_chars])
+        avg_char_r = _safe_avg([r.char_recall for r in samples_with_chars])
+        avg_char_f1 = _safe_avg([r.char_f1 for r in samples_with_chars])
+        avg_gt_char = _safe_avg([len(r.gt_character_tags) for r in samples_with_chars])
+        avg_sel_char = _safe_avg([len(r.selected_character_tags) for r in samples_with_chars])
+
+        print(f" Retrieval recall: {avg_char_retrieval_recall:.4f}")
+        print(f" Selection precision: {avg_char_p:.4f}")
+        print(f" Selection recall: {avg_char_r:.4f}")
+        print(f" Selection F1: {avg_char_f1:.4f}")
+        print(f" Avg gt char tags: {avg_gt_char:.1f}")
+        print(f" Avg selected chars: {avg_sel_char:.1f}")
+
+        # Show character-specific failures
+        char_misses = []
+        char_false_pos = []
+        for r in samples_with_chars:
+            missed = r.gt_character_tags - r.selected_character_tags
+            for m in missed:
+                char_misses.append((r.sample_id, m))
+            extra = r.selected_character_tags - r.gt_character_tags
+            for e in extra:
+                char_false_pos.append((r.sample_id, e))
+
+        if char_misses:
+            print(f"\n Missed characters ({len(char_misses)} total):")
+            for sid, tag in char_misses[:10]:
+                print(f" id={sid}: missed {tag}")
+
+        if char_false_pos:
+            print(f"\n False positive characters ({len(char_false_pos)} total):")
+            for sid, tag in char_false_pos[:10]:
+                print(f" id={sid}: wrongly selected {tag}")
+    else:
+        print(" (no samples had character tags in ground truth)")
+
+    # False positive characters in samples WITHOUT character ground-truth
+    no_char_gt_but_selected = [r for r in valid if not r.gt_character_tags and r.selected_character_tags]
+    if no_char_gt_but_selected:
+        print(f"\n Spurious character selections ({len(no_char_gt_but_selected)} samples):")
+        print(" (These samples had NO character in ground truth but system selected one)")
+        for r in no_char_gt_but_selected[:5]:
+            print(f" id={r.sample_id}: selected {sorted(r.selected_character_tags)}")
+
+    # --- General tag breakdown ---
+    print()
+    print("-" * 70)
+    print("GENERAL TAGS (non-character, non-copyright)")
+    print("-" * 70)
+    avg_gen_p = _safe_avg([r.general_precision for r in valid])
+    avg_gen_r = _safe_avg([r.general_recall for r in valid])
+    avg_gen_f1 = _safe_avg([r.general_f1 for r in valid])
+    avg_gt_gen = _safe_avg([len(r.gt_general_tags) for r in valid])
+    avg_sel_gen = _safe_avg([len(r.selected_general_tags) for r in valid])
+    print(f" Selection precision: {avg_gen_p:.4f}")
+    print(f" Selection recall: {avg_gen_r:.4f}")
+    print(f" Selection F1: {avg_gen_f1:.4f}")
+    print(f" Avg gt general tags: {avg_gt_gen:.1f}")
+    print(f" Avg selected general: {avg_sel_gen:.1f}")
+
     print()
+    print("-" * 70)
     print("Timing (avg per sample):")
     print(f" Stage 1 (rewrite): {avg_t1:.2f}s")
     print(f" Stage 2 (retrieval): {avg_t2:.2f}s")

@@ -311,7 +456,7 @@ def print_summary(results: List[SampleResult]) -> None:
 
     # Show worst and best F1 samples
     by_f1 = sorted(valid, key=lambda r: r.selection_f1)
-    print("Lowest F1 samples:")
+    print("Lowest F1 samples (overall):")
     for r in by_f1[:3]:
         print(f" id={r.sample_id} F1={r.selection_f1:.3f} P={r.selection_precision:.3f} R={r.selection_recall:.3f}")
         missed = r.ground_truth_tags - r.selected_tags

@@ -322,7 +467,7 @@ def print_summary(results: List[SampleResult]) -> None:
             print(f" extra: {sorted(extra)[:10]}")
 
     print()
-    print("Highest F1 samples:")
+    print("Highest F1 samples (overall):")
     for r in by_f1[-3:]:
         print(f" id={r.sample_id} F1={r.selection_f1:.3f} P={r.selection_precision:.3f} R={r.selection_recall:.3f}")
 

@@ -332,7 +477,7 @@ def print_summary(results: List[SampleResult]) -> None:
     for r in errored[:5]:
         print(f" id={r.sample_id}: {r.error}")
 
-    print("=" * 60)
+    print("=" * 70)
 
 
 def main(argv=None) -> int:

@@ -423,6 +568,21 @@ def main(argv=None) -> int:
             "selection_precision": round(r.selection_precision, 4),
             "selection_recall": round(r.selection_recall, 4),
             "selection_f1": round(r.selection_f1, 4),
+            # Character tag breakdown
+            "gt_character_tags": sorted(r.gt_character_tags),
+            "selected_character_tags": sorted(r.selected_character_tags),
+            "retrieved_character_tags": sorted(r.retrieved_character_tags),
+            "char_retrieval_recall": round(r.char_retrieval_recall, 4),
+            "char_precision": round(r.char_precision, 4),
+            "char_recall": round(r.char_recall, 4),
+            "char_f1": round(r.char_f1, 4),
+            # General tag breakdown
+            "gt_general_tags": sorted(r.gt_general_tags),
+            "selected_general_tags": sorted(r.selected_general_tags),
+            "general_precision": round(r.general_precision, 4),
+            "general_recall": round(r.general_recall, 4),
+            "general_f1": round(r.general_f1, 4),
+            # Timing
             "stage1_time": round(r.stage1_time, 3),
             "stage2_time": round(r.stage2_time, 3),
             "stage3_time": round(r.stage3_time, 3),