Claude committed on
Commit
12dfa28
·
1 Parent(s): f1b4da2

Add parallel processing to eval pipeline with ThreadPoolExecutor

Browse files

Samples are now processed concurrently (default 4 workers) using threads,
which is ideal since the bottleneck is I/O (OpenRouter API calls). Retrieval
assets are pre-warmed before threads start to avoid initialization races.
Use --workers 1 to revert to sequential mode.

https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG

Files changed (1) hide show
  1. scripts/eval_pipeline.py +229 -133
scripts/eval_pipeline.py CHANGED
@@ -16,6 +16,12 @@ Usage:
16
  # Reproducible run with specific seed:
17
  python scripts/eval_pipeline.py --n 50 --seed 123
18
 
 
 
 
 
 
 
19
  # Skip Stage 1 LLM rewrite (cheaper, tests Stage 2+3 only):
20
  python scripts/eval_pipeline.py --n 20 --skip-rewrite
21
 
@@ -38,7 +44,9 @@ import json
38
  import os
39
  import random
40
  import sys
 
41
  import time
 
42
  from dataclasses import dataclass, field
43
  from datetime import datetime
44
  from pathlib import Path
@@ -143,6 +151,169 @@ def _compute_metrics(predicted: Set[str], ground_truth: Set[str]) -> Tuple[float
143
  return precision, recall, f1
144
 
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def run_eval(
147
  n_samples: int = 20,
148
  caption_field: str = "caption_cogvlm",
@@ -156,17 +327,9 @@ def run_eval(
156
  verbose: bool = False,
157
  shuffle: bool = True,
158
  seed: int = 42,
 
159
  ) -> List[SampleResult]:
160
 
161
- from psq_rag.llm.rewrite import llm_rewrite_prompt
162
- from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
163
- from psq_rag.llm.select import llm_select_indices
164
- from psq_rag.retrieval.state import get_tag_type_name
165
-
166
- def log(msg: str) -> None:
167
- if verbose:
168
- print(f" {msg}")
169
-
170
  # Load eval samples
171
  if not EVAL_DATA_PATH.is_file():
172
  print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
@@ -196,134 +359,63 @@ def run_eval(
196
 
197
  print(f"Loaded {len(samples)}/{len(all_samples)} samples (caption_field={caption_field})")
198
  print(f"shuffle={shuffle}, seed={seed}, skip_rewrite={skip_rewrite}, allow_nsfw={allow_nsfw}, mode={mode}")
 
199
  print()
200
 
201
- results: List[SampleResult] = []
202
-
203
- for i, sample in enumerate(samples):
204
- sid = sample["id"]
205
- caption = sample["caption"]
206
- gt_tags = sample["gt_tags"]
207
-
208
- result = SampleResult(
209
- sample_id=sid,
210
- caption=caption[:120] + ("..." if len(caption) > 120 else ""),
211
- ground_truth_tags=gt_tags,
212
- )
213
-
214
- print(f"[{i+1}/{len(samples)}] id={sid} gt_tags={len(gt_tags)}")
215
-
216
- try:
217
- # --- Stage 1: LLM Rewrite ---
218
- if skip_rewrite:
219
- # Use the caption directly as comma-separated phrases
220
- phrases = [p.strip() for p in caption.split(",") if p.strip()]
221
- # Also split on periods/sentences for natural language captions
222
- if len(phrases) <= 1:
223
- phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()]
224
- result.rewrite_phrases = phrases
225
- result.stage1_time = 0.0
226
- else:
227
- t0 = time.time()
228
- rewritten = llm_rewrite_prompt(caption, log)
229
- result.stage1_time = time.time() - t0
230
- if rewritten:
231
- result.rewrite_phrases = [p.strip() for p in rewritten.split(",") if p.strip()]
232
- else:
233
- result.rewrite_phrases = [p.strip() for p in caption.split(",") if p.strip()]
234
- if len(result.rewrite_phrases) <= 1:
235
- result.rewrite_phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()]
236
-
237
- if verbose:
238
- log(f"Phrases ({len(result.rewrite_phrases)}): {result.rewrite_phrases[:5]}")
239
-
240
- # --- Stage 2: Retrieval ---
241
- t0 = time.time()
242
- retrieval_result = psq_candidates_from_rewrite_phrases(
243
- rewrite_phrases=result.rewrite_phrases,
244
- allow_nsfw_tags=allow_nsfw,
245
- global_k=300,
246
- verbose=False,
247
- )
248
- result.stage2_time = time.time() - t0
249
-
250
- if isinstance(retrieval_result, tuple):
251
- candidates, _ = retrieval_result
252
- else:
253
- candidates = retrieval_result
254
-
255
- result.retrieved_tags = {c.tag for c in candidates}
256
- # Retrieval recall: what fraction of ground truth was retrieved
257
- if gt_tags:
258
- result.retrieval_recall = len(result.retrieved_tags & gt_tags) / len(gt_tags)
259
-
260
- if verbose:
261
- log(f"Retrieved {len(candidates)} candidates, recall={result.retrieval_recall:.3f}")
262
 
263
- # --- Stage 3: LLM Selection ---
264
- t0 = time.time()
265
- picked_indices = llm_select_indices(
266
- query_text=caption,
267
- candidates=candidates,
268
- max_pick=0,
269
- log=log,
270
- mode=mode,
271
- chunk_size=chunk_size,
272
- per_phrase_k=per_phrase_k,
273
- temperature=temperature,
274
- max_tokens=max_tokens,
275
  )
276
- result.stage3_time = time.time() - t0
277
-
278
- result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()
279
-
280
- # Overall selection metrics
281
- p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
282
- result.selection_precision = p
283
- result.selection_recall = r
284
- result.selection_f1 = f1
285
-
286
- # Split ground-truth and selected tags by type
287
- gt_char, gt_gen = _classify_tags(gt_tags, get_tag_type_name)
288
- sel_char, sel_gen = _classify_tags(result.selected_tags, get_tag_type_name)
289
- ret_char, _ = _classify_tags(result.retrieved_tags, get_tag_type_name)
290
-
291
- result.gt_character_tags = gt_char
292
- result.selected_character_tags = sel_char
293
- result.retrieved_character_tags = ret_char
294
- result.gt_general_tags = gt_gen
295
- result.selected_general_tags = sel_gen
296
-
297
- # Character-specific metrics
298
- if gt_char:
299
- result.char_retrieval_recall = len(ret_char & gt_char) / len(gt_char)
300
- cp, cr, cf1 = _compute_metrics(sel_char, gt_char)
301
- result.char_precision = cp
302
- result.char_recall = cr
303
- result.char_f1 = cf1
304
-
305
- # General-tag metrics
306
- gp, gr, gf1 = _compute_metrics(sel_gen, gt_gen)
307
- result.general_precision = gp
308
- result.general_recall = gr
309
- result.general_f1 = gf1
310
-
311
- # Per-sample output line
312
- char_info = ""
313
- if gt_char:
314
- char_info = f" char[gt={len(gt_char)} sel={len(sel_char)} P={cp:.2f} R={cr:.2f}]"
315
- print(
316
- f" retrieval_recall={result.retrieval_recall:.3f} "
317
- f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
318
- f"selected={len(result.selected_tags)}{char_info} "
319
- f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
320
- )
321
-
322
- except Exception as e:
323
- result.error = str(e)
324
- print(f" ERROR: {e}")
325
-
326
- results.append(result)
327
 
328
  return results
329
 
@@ -506,6 +598,8 @@ def main(argv=None) -> int:
506
  help="Use samples in file order (first N)")
507
  ap.add_argument("--seed", type=int, default=42,
508
  help="Random seed for shuffle (default: 42)")
 
 
509
 
510
  args = ap.parse_args(list(argv) if argv is not None else None)
511
 
@@ -522,6 +616,7 @@ def main(argv=None) -> int:
522
  verbose=args.verbose,
523
  shuffle=args.shuffle,
524
  seed=args.seed,
 
525
  )
526
 
527
  print_summary(results)
@@ -551,6 +646,7 @@ def main(argv=None) -> int:
551
  "temperature": args.temperature,
552
  "shuffle": args.shuffle,
553
  "seed": args.seed,
 
554
  "n_errors": sum(1 for r in results if r.error),
555
  }
556
 
 
16
  # Reproducible run with specific seed:
17
  python scripts/eval_pipeline.py --n 50 --seed 123
18
 
19
+ # Parallel processing with 4 workers (default):
20
+ python scripts/eval_pipeline.py --n 50 --workers 4
21
+
22
+ # Sequential mode (disable parallelism):
23
+ python scripts/eval_pipeline.py --n 20 --workers 1
24
+
25
  # Skip Stage 1 LLM rewrite (cheaper, tests Stage 2+3 only):
26
  python scripts/eval_pipeline.py --n 20 --skip-rewrite
27
 
 
44
  import os
45
  import random
46
  import sys
47
+ import threading
48
  import time
49
+ from concurrent.futures import ThreadPoolExecutor, as_completed
50
  from dataclasses import dataclass, field
51
  from datetime import datetime
52
  from pathlib import Path
 
151
  return precision, recall, f1
152
 
153
 
154
def _process_one_sample(
    sample: Dict[str, Any],
    index: int,
    total: int,
    skip_rewrite: bool,
    allow_nsfw: bool,
    mode: str,
    chunk_size: int,
    per_phrase_k: int,
    temperature: float,
    max_tokens: int,
    verbose: bool,
    print_lock: threading.Lock,
) -> SampleResult:
    """Process a single eval sample through the full 3-stage pipeline.

    Thread-safe: all stdout writes are serialized through ``print_lock`` and
    no shared mutable state is touched. Stage durations are measured with
    ``time.perf_counter()`` (monotonic) instead of ``time.time()`` so that
    wall-clock adjustments (NTP steps, DST) cannot skew the reported t1/t2/t3.

    Args:
        sample: Eval record with "id", "caption" and "gt_tags" keys.
        index: 0-based position of this sample; used only in log prefixes.
        total: Total sample count for the run; used only in log prefixes.
        skip_rewrite: When True, skip Stage 1 and split the caption directly.
        allow_nsfw: Forwarded to Stage 2 retrieval.
        mode, chunk_size, per_phrase_k, temperature, max_tokens: Forwarded to
            Stage 3 selection.
        verbose: Enables per-stage debug logging via ``log``.
        print_lock: Lock guarding interleaved prints across worker threads.

    Returns:
        A populated SampleResult. On any per-sample failure, ``result.error``
        is set and the partially-filled result is returned (never raises).
    """
    # Function-local imports keep module import cheap; the heavy retrieval
    # assets are expected to be pre-warmed by the caller before threads start.
    from psq_rag.llm.rewrite import llm_rewrite_prompt
    from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
    from psq_rag.llm.select import llm_select_indices
    from psq_rag.retrieval.state import get_tag_type_name

    def log(msg: str) -> None:
        # Per-sample prefixed debug output, serialized across threads.
        if verbose:
            with print_lock:
                print(f" [{index+1}] {msg}")

    sid = sample["id"]
    caption = sample["caption"]
    gt_tags = sample["gt_tags"]

    result = SampleResult(
        sample_id=sid,
        caption=caption[:120] + ("..." if len(caption) > 120 else ""),
        ground_truth_tags=gt_tags,
    )

    with print_lock:
        print(f"[{index+1}/{total}] id={sid} gt_tags={len(gt_tags)}")

    try:
        # --- Stage 1: LLM Rewrite ---
        if skip_rewrite:
            # Use the caption directly as comma-separated phrases; fall back
            # to period-splitting for natural-language captions.
            phrases = [p.strip() for p in caption.split(",") if p.strip()]
            if len(phrases) <= 1:
                phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()]
            result.rewrite_phrases = phrases
            result.stage1_time = 0.0
        else:
            t0 = time.perf_counter()
            rewritten = llm_rewrite_prompt(caption, log)
            result.stage1_time = time.perf_counter() - t0
            if rewritten:
                result.rewrite_phrases = [p.strip() for p in rewritten.split(",") if p.strip()]
            else:
                # LLM returned nothing usable: fall back to splitting the caption.
                result.rewrite_phrases = [p.strip() for p in caption.split(",") if p.strip()]
                if len(result.rewrite_phrases) <= 1:
                    result.rewrite_phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()]

        log(f"Phrases ({len(result.rewrite_phrases)}): {result.rewrite_phrases[:5]}")

        # --- Stage 2: Retrieval ---
        t0 = time.perf_counter()
        retrieval_result = psq_candidates_from_rewrite_phrases(
            rewrite_phrases=result.rewrite_phrases,
            allow_nsfw_tags=allow_nsfw,
            global_k=300,
            verbose=False,
        )
        result.stage2_time = time.perf_counter() - t0

        # Some versions of the retrieval API return (candidates, extra);
        # tolerate both shapes.
        if isinstance(retrieval_result, tuple):
            candidates, _ = retrieval_result
        else:
            candidates = retrieval_result

        result.retrieved_tags = {c.tag for c in candidates}
        # Retrieval recall: what fraction of ground truth was retrieved.
        # NOTE(review): when gt_tags is empty this stays at the dataclass
        # default — presumably 0.0; confirm against SampleResult's definition.
        if gt_tags:
            result.retrieval_recall = len(result.retrieved_tags & gt_tags) / len(gt_tags)

        log(f"Retrieved {len(candidates)} candidates, recall={result.retrieval_recall:.3f}")

        # --- Stage 3: LLM Selection ---
        t0 = time.perf_counter()
        picked_indices = llm_select_indices(
            query_text=caption,
            candidates=candidates,
            max_pick=0,
            log=log,
            mode=mode,
            chunk_size=chunk_size,
            per_phrase_k=per_phrase_k,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        result.stage3_time = time.perf_counter() - t0

        result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()

        # Overall selection metrics
        p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
        result.selection_precision = p
        result.selection_recall = r
        result.selection_f1 = f1

        # Split ground-truth / selected / retrieved tags by type
        gt_char, gt_gen = _classify_tags(gt_tags, get_tag_type_name)
        sel_char, sel_gen = _classify_tags(result.selected_tags, get_tag_type_name)
        ret_char, _ = _classify_tags(result.retrieved_tags, get_tag_type_name)

        result.gt_character_tags = gt_char
        result.selected_character_tags = sel_char
        result.retrieved_character_tags = ret_char
        result.gt_general_tags = gt_gen
        result.selected_general_tags = sel_gen

        # Character-specific metrics (only meaningful when GT has character tags)
        if gt_char:
            result.char_retrieval_recall = len(ret_char & gt_char) / len(gt_char)
            cp, cr, cf1 = _compute_metrics(sel_char, gt_char)
            result.char_precision = cp
            result.char_recall = cr
            result.char_f1 = cf1

        # General-tag metrics
        gp, gr, gf1 = _compute_metrics(sel_gen, gt_gen)
        result.general_precision = gp
        result.general_recall = gr
        result.general_f1 = gf1

        # Per-sample output line
        char_info = ""
        if gt_char:
            # cp/cr are guaranteed bound here: assigned in the gt_char branch above.
            char_info = f" char[gt={len(gt_char)} sel={len(sel_char)} P={cp:.2f} R={cr:.2f}]"
        with print_lock:
            print(
                f" [{index+1}] retrieval_recall={result.retrieval_recall:.3f} "
                f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
                f"selected={len(result.selected_tags)}{char_info} "
                f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
            )

    except Exception as e:
        # Per-sample isolation: one bad sample must not kill the whole run.
        result.error = str(e)
        with print_lock:
            print(f" [{index+1}] ERROR: {e}")

    return result
300
+
301
+
302
def _prewarm_retrieval_assets() -> None:
    """Eagerly initialize every lazily-loaded retrieval asset.

    Run once before worker threads are spawned, so concurrent first-use of
    the shared assets cannot race on initialization.
    """
    from psq_rag.retrieval.state import (
        get_tag2aliases,
        get_tag_type_name,
        get_tfidf_components,
    )

    print("Pre-warming retrieval assets (TF-IDF, FastText, HNSW, aliases)...")
    started = time.time()
    # Touch each lazy accessor once; the return values are discarded.
    get_tfidf_components()  # joblib artifacts, HNSW indexes, FastText model
    get_tag2aliases()  # alias dictionary loaded from CSV
    get_tag_type_name("_warmup_")  # builds the tag-type lookup table
    print(f" Assets loaded in {time.time() - started:.1f}s")
315
+
316
+
317
  def run_eval(
318
  n_samples: int = 20,
319
  caption_field: str = "caption_cogvlm",
 
327
  verbose: bool = False,
328
  shuffle: bool = True,
329
  seed: int = 42,
330
+ workers: int = 1,
331
  ) -> List[SampleResult]:
332
 
 
 
 
 
 
 
 
 
 
333
  # Load eval samples
334
  if not EVAL_DATA_PATH.is_file():
335
  print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
 
359
 
360
  print(f"Loaded {len(samples)}/{len(all_samples)} samples (caption_field={caption_field})")
361
  print(f"shuffle={shuffle}, seed={seed}, skip_rewrite={skip_rewrite}, allow_nsfw={allow_nsfw}, mode={mode}")
362
+ print(f"workers={workers}")
363
  print()
364
 
365
+ # Pre-warm shared retrieval assets before spawning threads
366
+ _prewarm_retrieval_assets()
367
+ print()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
+ print_lock = threading.Lock()
370
+ total = len(samples)
371
+
372
+ if workers <= 1:
373
+ # Sequential mode (original behavior)
374
+ results: List[SampleResult] = []
375
+ for i, sample in enumerate(samples):
376
+ result = _process_one_sample(
377
+ sample, i, total,
378
+ skip_rewrite, allow_nsfw, mode, chunk_size,
379
+ per_phrase_k, temperature, max_tokens, verbose,
380
+ print_lock,
381
  )
382
+ results.append(result)
383
+ else:
384
+ # Parallel mode
385
+ print(f"Processing {total} samples with {workers} parallel workers...")
386
+ print()
387
+ # Submit all samples; use index to preserve original ordering
388
+ results_by_index: Dict[int, SampleResult] = {}
389
+ with ThreadPoolExecutor(max_workers=workers) as executor:
390
+ futures = {
391
+ executor.submit(
392
+ _process_one_sample,
393
+ sample, i, total,
394
+ skip_rewrite, allow_nsfw, mode, chunk_size,
395
+ per_phrase_k, temperature, max_tokens, verbose,
396
+ print_lock,
397
+ ): i
398
+ for i, sample in enumerate(samples)
399
+ }
400
+ for future in as_completed(futures):
401
+ idx = futures[future]
402
+ try:
403
+ results_by_index[idx] = future.result()
404
+ except Exception as e:
405
+ # Should not happen since _process_one_sample catches exceptions,
406
+ # but guard against unexpected errors
407
+ with print_lock:
408
+ print(f" [{idx+1}] WORKER ERROR: {e}")
409
+ result = SampleResult(
410
+ sample_id=samples[idx]["id"],
411
+ caption=samples[idx]["caption"][:120],
412
+ ground_truth_tags=samples[idx]["gt_tags"],
413
+ error=f"Worker error: {e}",
414
+ )
415
+ results_by_index[idx] = result
416
+
417
+ # Reassemble in original order
418
+ results = [results_by_index[i] for i in range(total)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
 
420
  return results
421
 
 
598
  help="Use samples in file order (first N)")
599
  ap.add_argument("--seed", type=int, default=42,
600
  help="Random seed for shuffle (default: 42)")
601
+ ap.add_argument("--workers", "-w", type=int, default=4,
602
+ help="Number of parallel workers (default: 4, use 1 for sequential)")
603
 
604
  args = ap.parse_args(list(argv) if argv is not None else None)
605
 
 
616
  verbose=args.verbose,
617
  shuffle=args.shuffle,
618
  seed=args.seed,
619
+ workers=args.workers,
620
  )
621
 
622
  print_summary(results)
 
646
  "temperature": args.temperature,
647
  "shuffle": args.shuffle,
648
  "seed": args.seed,
649
+ "workers": args.workers,
650
  "n_errors": sum(1 for r in results if r.error),
651
  }
652