Claude committed
Commit 6909d06 · 1 Parent(s): ea9e11c

Add end-to-end evaluation harness for pipeline metrics


scripts/eval_pipeline.py measures per-stage and overall quality:
- Stage 2 retrieval recall@k (fraction of ground-truth tags retrieved)
- Stage 3 selection precision, recall, F1 (final output vs ground truth)
- Per-sample timing for each stage
- Summary with worst/best F1 samples and missed/extra tag analysis

Uses the e621_sfw_sample_1000 eval dataset, which provides multiple caption fields.
Supports --skip-rewrite mode and JSONL output for detailed analysis.
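
To make the metric definitions concrete, a made-up mini-example (not harness output): with ground truth {canine, solo, smiling} and final selection {canine, solo, sitting}, precision = recall = F1 = 2/3; if canine and solo appear among the retrieved candidates but smiling does not, Stage 2 recall@k is likewise 2/3.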

https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG

Files changed (1)
  1. scripts/eval_pipeline.py +391 -0
scripts/eval_pipeline.py ADDED
@@ -0,0 +1,391 @@
"""End-to-end evaluation harness for the Prompt Squirrel RAG pipeline.

Measures per-stage and overall metrics using ground-truth tagged samples
from the e621 evaluation dataset.

Metrics computed:
- Stage 2 (Retrieval): Recall@k — what fraction of ground-truth tags
  appear among the retrieved candidates
- Stage 3 (Selection): Precision, Recall, F1 — how well the final
  selected tags match the ground truth

Usage:
    # Full end-to-end (Stage 1 + 2 + 3):
    python scripts/eval_pipeline.py --n 20

    # Skip Stage 1 LLM rewrite and split the caption directly into phrases:
    python scripts/eval_pipeline.py --n 20 --skip-rewrite

    # Use a specific caption field:
    python scripts/eval_pipeline.py --n 20 --caption-field caption_cogvlm

Requires:
- OPENROUTER_API_KEY env var (for Stage 1 rewrite and Stage 3 selection)
- fluffyrock_3m.csv and other retrieval assets in the project root
- data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000.jsonl
"""

from __future__ import annotations

import argparse
import json
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
os.chdir(_REPO_ROOT)

EVAL_DATA_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"


def _flatten_ground_truth_tags(tags_categorized_str: str) -> Set[str]:
    """Parse the categorized ground-truth JSON string into a flat set of tags."""
    if not tags_categorized_str:
        return set()
    try:
        cats = json.loads(tags_categorized_str)
    except json.JSONDecodeError:
        return set()
    tags = set()
    for tag_list in cats.values():
        if isinstance(tag_list, list):
            for t in tag_list:
                tags.add(t.strip())
    return tags


@dataclass
class SampleResult:
    sample_id: Any
    caption: str
    ground_truth_tags: Set[str]
    # Stage 1
    rewrite_phrases: List[str] = field(default_factory=list)
    # Stage 2
    retrieved_tags: Set[str] = field(default_factory=set)
    retrieval_recall: float = 0.0
    # Stage 3
    selected_tags: Set[str] = field(default_factory=set)
    selection_precision: float = 0.0
    selection_recall: float = 0.0
    selection_f1: float = 0.0
    # Timing
    stage1_time: float = 0.0
    stage2_time: float = 0.0
    stage3_time: float = 0.0
    # Errors
    error: Optional[str] = None


def _compute_metrics(predicted: Set[str], ground_truth: Set[str]) -> Tuple[float, float, float]:
    """Compute precision, recall, F1."""
    if not predicted and not ground_truth:
        return 1.0, 1.0, 1.0
    if not predicted:
        return 0.0, 0.0, 0.0
    if not ground_truth:
        return 0.0, 0.0, 0.0

    tp = len(predicted & ground_truth)
    precision = tp / len(predicted)
    recall = tp / len(ground_truth)
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1


def run_eval(
    n_samples: int = 20,
    caption_field: str = "caption_cogvlm",
    skip_rewrite: bool = False,
    allow_nsfw: bool = False,
    mode: str = "chunked_map_union",
    chunk_size: int = 60,
    per_phrase_k: int = 2,
    temperature: float = 0.0,
    max_tokens: int = 512,
    verbose: bool = False,
) -> List[SampleResult]:

    from psq_rag.llm.rewrite import llm_rewrite_prompt
    from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
    from psq_rag.llm.select import llm_select_indices

    def log(msg: str) -> None:
        if verbose:
            print(f"  {msg}")

    # Load eval samples
    if not EVAL_DATA_PATH.is_file():
        print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
        sys.exit(1)

    samples = []
    with EVAL_DATA_PATH.open("r", encoding="utf-8") as f:
        for line in f:
            if len(samples) >= n_samples:
                break
            row = json.loads(line)
            caption = row.get(caption_field, "")
            if not caption or not caption.strip():
                continue
            gt_tags = _flatten_ground_truth_tags(row.get("tags_ground_truth_categorized", ""))
            if not gt_tags:
                continue
            samples.append({
                "id": row.get("id", row.get("row_id", len(samples))),
                "caption": caption.strip(),
                "gt_tags": gt_tags,
            })

    print(f"Loaded {len(samples)} samples (caption_field={caption_field})")
    print(f"skip_rewrite={skip_rewrite}, allow_nsfw={allow_nsfw}, mode={mode}")
    print()

    results: List[SampleResult] = []

    for i, sample in enumerate(samples):
        sid = sample["id"]
        caption = sample["caption"]
        gt_tags = sample["gt_tags"]

        result = SampleResult(
            sample_id=sid,
            caption=caption[:120] + ("..." if len(caption) > 120 else ""),
            ground_truth_tags=gt_tags,
        )

        print(f"[{i+1}/{len(samples)}] id={sid} gt_tags={len(gt_tags)}")

        try:
            # --- Stage 1: LLM Rewrite ---
            if skip_rewrite:
                # Use the caption directly as comma-separated phrases
                phrases = [p.strip() for p in caption.split(",") if p.strip()]
                # Also split on periods/sentences for natural-language captions
                if len(phrases) <= 1:
                    phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()]
                result.rewrite_phrases = phrases
                result.stage1_time = 0.0
            else:
                t0 = time.time()
                rewritten = llm_rewrite_prompt(caption, log)
                result.stage1_time = time.time() - t0
                if rewritten:
                    result.rewrite_phrases = [p.strip() for p in rewritten.split(",") if p.strip()]
                else:
                    # Fall back to splitting the raw caption if the rewrite came back empty
                    result.rewrite_phrases = [p.strip() for p in caption.split(",") if p.strip()]
                    if len(result.rewrite_phrases) <= 1:
                        result.rewrite_phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()]

            if verbose:
                log(f"Phrases ({len(result.rewrite_phrases)}): {result.rewrite_phrases[:5]}")

            # --- Stage 2: Retrieval ---
            t0 = time.time()
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=result.rewrite_phrases,
                allow_nsfw_tags=allow_nsfw,
                global_k=300,
                verbose=False,
            )
            result.stage2_time = time.time() - t0

            if isinstance(retrieval_result, tuple):
                candidates, _ = retrieval_result
            else:
                candidates = retrieval_result

            result.retrieved_tags = {c.tag for c in candidates}
            # Retrieval recall: what fraction of ground truth was retrieved
            if gt_tags:
                result.retrieval_recall = len(result.retrieved_tags & gt_tags) / len(gt_tags)

            if verbose:
                log(f"Retrieved {len(candidates)} candidates, recall={result.retrieval_recall:.3f}")

            # --- Stage 3: LLM Selection ---
            t0 = time.time()
            picked_indices = llm_select_indices(
                query_text=caption,
                candidates=candidates,
                max_pick=0,
                log=log,
                mode=mode,
                chunk_size=chunk_size,
                per_phrase_k=per_phrase_k,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            result.stage3_time = time.time() - t0

            result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()

            # Selection metrics
            p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
            result.selection_precision = p
            result.selection_recall = r
            result.selection_f1 = f1

            print(
                f"  retrieval_recall={result.retrieval_recall:.3f} "
                f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
                f"selected={len(result.selected_tags)} "
                f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
            )

        except Exception as e:
            result.error = str(e)
            print(f"  ERROR: {e}")

        results.append(result)

    return results


def print_summary(results: List[SampleResult]) -> None:
    """Print aggregate metrics across all samples."""
    valid = [r for r in results if r.error is None]
    errored = [r for r in results if r.error is not None]

    if not valid:
        print("\nNo valid results to summarize.")
        return

    n = len(valid)

    avg_retrieval_recall = sum(r.retrieval_recall for r in valid) / n
    avg_sel_precision = sum(r.selection_precision for r in valid) / n
    avg_sel_recall = sum(r.selection_recall for r in valid) / n
    avg_sel_f1 = sum(r.selection_f1 for r in valid) / n

    avg_retrieved = sum(len(r.retrieved_tags) for r in valid) / n
    avg_selected = sum(len(r.selected_tags) for r in valid) / n
    avg_gt = sum(len(r.ground_truth_tags) for r in valid) / n

    avg_t1 = sum(r.stage1_time for r in valid) / n
    avg_t2 = sum(r.stage2_time for r in valid) / n
    avg_t3 = sum(r.stage3_time for r in valid) / n

    print()
    print("=" * 60)
    print(f"EVALUATION SUMMARY ({n} samples, {len(errored)} errors)")
    print("=" * 60)
    print()
    print("Stage 2 - Retrieval:")
    print(f"  Avg recall@300: {avg_retrieval_recall:.4f}")
    print(f"  Avg candidates: {avg_retrieved:.1f}")
    print()
    print("Stage 3 - Selection (final output):")
    print(f"  Avg precision: {avg_sel_precision:.4f}")
    print(f"  Avg recall: {avg_sel_recall:.4f}")
    print(f"  Avg F1: {avg_sel_f1:.4f}")
    print(f"  Avg selected tags: {avg_selected:.1f}")
    print(f"  Avg ground-truth tags: {avg_gt:.1f}")
    print()
    print("Timing (avg per sample):")
    print(f"  Stage 1 (rewrite): {avg_t1:.2f}s")
    print(f"  Stage 2 (retrieval): {avg_t2:.2f}s")
    print(f"  Stage 3 (selection): {avg_t3:.2f}s")
    print(f"  Total: {avg_t1 + avg_t2 + avg_t3:.2f}s")
    print()

    # Show worst and best F1 samples
    by_f1 = sorted(valid, key=lambda r: r.selection_f1)
    print("Lowest F1 samples:")
    for r in by_f1[:3]:
        print(f"  id={r.sample_id} F1={r.selection_f1:.3f} P={r.selection_precision:.3f} R={r.selection_recall:.3f}")
        missed = r.ground_truth_tags - r.selected_tags
        extra = r.selected_tags - r.ground_truth_tags
        if missed:
            print(f"    missed: {sorted(missed)[:10]}")
        if extra:
            print(f"    extra: {sorted(extra)[:10]}")

    print()
    print("Highest F1 samples:")
    for r in by_f1[-3:]:
        print(f"  id={r.sample_id} F1={r.selection_f1:.3f} P={r.selection_precision:.3f} R={r.selection_recall:.3f}")

    if errored:
        print()
        print(f"Errors ({len(errored)}):")
        for r in errored[:5]:
            print(f"  id={r.sample_id}: {r.error}")

    print("=" * 60)


def main(argv=None) -> int:
    ap = argparse.ArgumentParser(description="End-to-end pipeline evaluation")
    ap.add_argument("--n", type=int, default=20, help="Number of samples to evaluate")
    ap.add_argument("--caption-field", default="caption_cogvlm",
                    choices=["caption_cogvlm", "caption_llm_0", "caption_llm_1",
                             "caption_llm_2", "caption_llm_3", "caption_llm_4",
                             "caption_llm_5", "caption_llm_6", "caption_llm_7"],
                    help="Which caption field to use as input")
    ap.add_argument("--skip-rewrite", action="store_true",
                    help="Skip Stage 1 LLM rewrite; split caption directly into phrases")
    ap.add_argument("--allow-nsfw", action="store_true", help="Allow NSFW tags")
    ap.add_argument("--mode", default="chunked_map_union",
                    choices=["single_shot", "chunked_map_union"])
    ap.add_argument("--chunk-size", type=int, default=60)
    ap.add_argument("--per-phrase-k", type=int, default=2)
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--max-tokens", type=int, default=512)
    ap.add_argument("--verbose", "-v", action="store_true", help="Show per-call Stage 3 logs")
    ap.add_argument("--output", "-o", type=str, default=None,
                    help="Save detailed results as JSONL to this path")

    args = ap.parse_args(list(argv) if argv is not None else None)

    results = run_eval(
        n_samples=args.n,
        caption_field=args.caption_field,
        skip_rewrite=args.skip_rewrite,
        allow_nsfw=args.allow_nsfw,
        mode=args.mode,
        chunk_size=args.chunk_size,
        per_phrase_k=args.per_phrase_k,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        verbose=args.verbose,
    )

    print_summary(results)

    # Optionally save detailed results
    if args.output:
        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as f:
            for r in results:
                row = {
                    "sample_id": r.sample_id,
                    "caption": r.caption,
                    "ground_truth_tags": sorted(r.ground_truth_tags),
                    "rewrite_phrases": r.rewrite_phrases,
                    "retrieved_tags": sorted(r.retrieved_tags),
                    "selected_tags": sorted(r.selected_tags),
                    "retrieval_recall": round(r.retrieval_recall, 4),
                    "selection_precision": round(r.selection_precision, 4),
                    "selection_recall": round(r.selection_recall, 4),
                    "selection_f1": round(r.selection_f1, 4),
                    "stage1_time": round(r.stage1_time, 3),
                    "stage2_time": round(r.stage2_time, 3),
                    "stage3_time": round(r.stage3_time, 3),
                    "error": r.error,
                }
                f.write(json.dumps(row, ensure_ascii=False) + "\n")
        print(f"\nDetailed results saved to: {out_path}")

    return 0


if __name__ == "__main__":
    sys.exit(main())
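
For downstream analysis, the JSONL written via --output can be reloaded directly. A minimal sketch (the path and the sort are illustrative; the field names match the row dict written in main()):

import json

# Hypothetical path passed earlier via --output
with open("eval_results.jsonl", encoding="utf-8") as f:
    rows = [json.loads(line) for line in f]

# Re-rank samples by final F1, mirroring the worst/best breakdown in print_summary()
rows.sort(key=lambda r: r["selection_f1"])
for r in rows[:3]:
    print(r["sample_id"], r["selection_f1"], r["retrieval_recall"])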