Spaces:
Running
Running
Claude committed on
Commit ·
12dfa28
1
Parent(s): f1b4da2
Add parallel processing to eval pipeline with ThreadPoolExecutor
Browse files
Samples are now processed concurrently (default 4 workers) using threads,
which is ideal since the bottleneck is I/O (OpenRouter API calls). Retrieval
assets are pre-warmed before threads start to avoid initialization races.
Use --workers 1 to revert to sequential mode.
https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG
- scripts/eval_pipeline.py +229 -133
scripts/eval_pipeline.py
CHANGED
|
@@ -16,6 +16,12 @@ Usage:
|
|
| 16 |
# Reproducible run with specific seed:
|
| 17 |
python scripts/eval_pipeline.py --n 50 --seed 123
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
# Skip Stage 1 LLM rewrite (cheaper, tests Stage 2+3 only):
|
| 20 |
python scripts/eval_pipeline.py --n 20 --skip-rewrite
|
| 21 |
|
|
@@ -38,7 +44,9 @@ import json
|
|
| 38 |
import os
|
| 39 |
import random
|
| 40 |
import sys
|
|
|
|
| 41 |
import time
|
|
|
|
| 42 |
from dataclasses import dataclass, field
|
| 43 |
from datetime import datetime
|
| 44 |
from pathlib import Path
|
|
@@ -143,6 +151,169 @@ def _compute_metrics(predicted: Set[str], ground_truth: Set[str]) -> Tuple[float
|
|
| 143 |
return precision, recall, f1
|
| 144 |
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
def run_eval(
|
| 147 |
n_samples: int = 20,
|
| 148 |
caption_field: str = "caption_cogvlm",
|
|
@@ -156,17 +327,9 @@ def run_eval(
|
|
| 156 |
verbose: bool = False,
|
| 157 |
shuffle: bool = True,
|
| 158 |
seed: int = 42,
|
|
|
|
| 159 |
) -> List[SampleResult]:
|
| 160 |
|
| 161 |
-
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 162 |
-
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
|
| 163 |
-
from psq_rag.llm.select import llm_select_indices
|
| 164 |
-
from psq_rag.retrieval.state import get_tag_type_name
|
| 165 |
-
|
| 166 |
-
def log(msg: str) -> None:
|
| 167 |
-
if verbose:
|
| 168 |
-
print(f" {msg}")
|
| 169 |
-
|
| 170 |
# Load eval samples
|
| 171 |
if not EVAL_DATA_PATH.is_file():
|
| 172 |
print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
|
|
@@ -196,134 +359,63 @@ def run_eval(
|
|
| 196 |
|
| 197 |
print(f"Loaded {len(samples)}/{len(all_samples)} samples (caption_field={caption_field})")
|
| 198 |
print(f"shuffle={shuffle}, seed={seed}, skip_rewrite={skip_rewrite}, allow_nsfw={allow_nsfw}, mode={mode}")
|
|
|
|
| 199 |
print()
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
sid = sample["id"]
|
| 205 |
-
caption = sample["caption"]
|
| 206 |
-
gt_tags = sample["gt_tags"]
|
| 207 |
-
|
| 208 |
-
result = SampleResult(
|
| 209 |
-
sample_id=sid,
|
| 210 |
-
caption=caption[:120] + ("..." if len(caption) > 120 else ""),
|
| 211 |
-
ground_truth_tags=gt_tags,
|
| 212 |
-
)
|
| 213 |
-
|
| 214 |
-
print(f"[{i+1}/{len(samples)}] id={sid} gt_tags={len(gt_tags)}")
|
| 215 |
-
|
| 216 |
-
try:
|
| 217 |
-
# --- Stage 1: LLM Rewrite ---
|
| 218 |
-
if skip_rewrite:
|
| 219 |
-
# Use the caption directly as comma-separated phrases
|
| 220 |
-
phrases = [p.strip() for p in caption.split(",") if p.strip()]
|
| 221 |
-
# Also split on periods/sentences for natural language captions
|
| 222 |
-
if len(phrases) <= 1:
|
| 223 |
-
phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()]
|
| 224 |
-
result.rewrite_phrases = phrases
|
| 225 |
-
result.stage1_time = 0.0
|
| 226 |
-
else:
|
| 227 |
-
t0 = time.time()
|
| 228 |
-
rewritten = llm_rewrite_prompt(caption, log)
|
| 229 |
-
result.stage1_time = time.time() - t0
|
| 230 |
-
if rewritten:
|
| 231 |
-
result.rewrite_phrases = [p.strip() for p in rewritten.split(",") if p.strip()]
|
| 232 |
-
else:
|
| 233 |
-
result.rewrite_phrases = [p.strip() for p in caption.split(",") if p.strip()]
|
| 234 |
-
if len(result.rewrite_phrases) <= 1:
|
| 235 |
-
result.rewrite_phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()]
|
| 236 |
-
|
| 237 |
-
if verbose:
|
| 238 |
-
log(f"Phrases ({len(result.rewrite_phrases)}): {result.rewrite_phrases[:5]}")
|
| 239 |
-
|
| 240 |
-
# --- Stage 2: Retrieval ---
|
| 241 |
-
t0 = time.time()
|
| 242 |
-
retrieval_result = psq_candidates_from_rewrite_phrases(
|
| 243 |
-
rewrite_phrases=result.rewrite_phrases,
|
| 244 |
-
allow_nsfw_tags=allow_nsfw,
|
| 245 |
-
global_k=300,
|
| 246 |
-
verbose=False,
|
| 247 |
-
)
|
| 248 |
-
result.stage2_time = time.time() - t0
|
| 249 |
-
|
| 250 |
-
if isinstance(retrieval_result, tuple):
|
| 251 |
-
candidates, _ = retrieval_result
|
| 252 |
-
else:
|
| 253 |
-
candidates = retrieval_result
|
| 254 |
-
|
| 255 |
-
result.retrieved_tags = {c.tag for c in candidates}
|
| 256 |
-
# Retrieval recall: what fraction of ground truth was retrieved
|
| 257 |
-
if gt_tags:
|
| 258 |
-
result.retrieval_recall = len(result.retrieved_tags & gt_tags) / len(gt_tags)
|
| 259 |
-
|
| 260 |
-
if verbose:
|
| 261 |
-
log(f"Retrieved {len(candidates)} candidates, recall={result.retrieval_recall:.3f}")
|
| 262 |
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
temperature
|
| 274 |
-
|
| 275 |
)
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
if gt_char:
|
| 314 |
-
char_info = f" char[gt={len(gt_char)} sel={len(sel_char)} P={cp:.2f} R={cr:.2f}]"
|
| 315 |
-
print(
|
| 316 |
-
f" retrieval_recall={result.retrieval_recall:.3f} "
|
| 317 |
-
f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
|
| 318 |
-
f"selected={len(result.selected_tags)}{char_info} "
|
| 319 |
-
f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
|
| 320 |
-
)
|
| 321 |
-
|
| 322 |
-
except Exception as e:
|
| 323 |
-
result.error = str(e)
|
| 324 |
-
print(f" ERROR: {e}")
|
| 325 |
-
|
| 326 |
-
results.append(result)
|
| 327 |
|
| 328 |
return results
|
| 329 |
|
|
@@ -506,6 +598,8 @@ def main(argv=None) -> int:
|
|
| 506 |
help="Use samples in file order (first N)")
|
| 507 |
ap.add_argument("--seed", type=int, default=42,
|
| 508 |
help="Random seed for shuffle (default: 42)")
|
|
|
|
|
|
|
| 509 |
|
| 510 |
args = ap.parse_args(list(argv) if argv is not None else None)
|
| 511 |
|
|
@@ -522,6 +616,7 @@ def main(argv=None) -> int:
|
|
| 522 |
verbose=args.verbose,
|
| 523 |
shuffle=args.shuffle,
|
| 524 |
seed=args.seed,
|
|
|
|
| 525 |
)
|
| 526 |
|
| 527 |
print_summary(results)
|
|
@@ -551,6 +646,7 @@ def main(argv=None) -> int:
|
|
| 551 |
"temperature": args.temperature,
|
| 552 |
"shuffle": args.shuffle,
|
| 553 |
"seed": args.seed,
|
|
|
|
| 554 |
"n_errors": sum(1 for r in results if r.error),
|
| 555 |
}
|
| 556 |
|
|
|
|
| 16 |
# Reproducible run with specific seed:
|
| 17 |
python scripts/eval_pipeline.py --n 50 --seed 123
|
| 18 |
|
| 19 |
+
# Parallel processing with 4 workers (default):
|
| 20 |
+
python scripts/eval_pipeline.py --n 50 --workers 4
|
| 21 |
+
|
| 22 |
+
# Sequential mode (disable parallelism):
|
| 23 |
+
python scripts/eval_pipeline.py --n 20 --workers 1
|
| 24 |
+
|
| 25 |
# Skip Stage 1 LLM rewrite (cheaper, tests Stage 2+3 only):
|
| 26 |
python scripts/eval_pipeline.py --n 20 --skip-rewrite
|
| 27 |
|
|
|
|
| 44 |
import os
|
| 45 |
import random
|
| 46 |
import sys
|
| 47 |
+
import threading
|
| 48 |
import time
|
| 49 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 50 |
from dataclasses import dataclass, field
|
| 51 |
from datetime import datetime
|
| 52 |
from pathlib import Path
|
|
|
|
| 151 |
return precision, recall, f1
|
| 152 |
|
| 153 |
|
| 154 |
+
def _process_one_sample(
|
| 155 |
+
sample: Dict[str, Any],
|
| 156 |
+
index: int,
|
| 157 |
+
total: int,
|
| 158 |
+
skip_rewrite: bool,
|
| 159 |
+
allow_nsfw: bool,
|
| 160 |
+
mode: str,
|
| 161 |
+
chunk_size: int,
|
| 162 |
+
per_phrase_k: int,
|
| 163 |
+
temperature: float,
|
| 164 |
+
max_tokens: int,
|
| 165 |
+
verbose: bool,
|
| 166 |
+
print_lock: threading.Lock,
|
| 167 |
+
) -> SampleResult:
|
| 168 |
+
"""Process a single eval sample through the full pipeline. Thread-safe."""
|
| 169 |
+
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 170 |
+
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
|
| 171 |
+
from psq_rag.llm.select import llm_select_indices
|
| 172 |
+
from psq_rag.retrieval.state import get_tag_type_name
|
| 173 |
+
|
| 174 |
+
def log(msg: str) -> None:
|
| 175 |
+
if verbose:
|
| 176 |
+
with print_lock:
|
| 177 |
+
print(f" [{index+1}] {msg}")
|
| 178 |
+
|
| 179 |
+
sid = sample["id"]
|
| 180 |
+
caption = sample["caption"]
|
| 181 |
+
gt_tags = sample["gt_tags"]
|
| 182 |
+
|
| 183 |
+
result = SampleResult(
|
| 184 |
+
sample_id=sid,
|
| 185 |
+
caption=caption[:120] + ("..." if len(caption) > 120 else ""),
|
| 186 |
+
ground_truth_tags=gt_tags,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
with print_lock:
|
| 190 |
+
print(f"[{index+1}/{total}] id={sid} gt_tags={len(gt_tags)}")
|
| 191 |
+
|
| 192 |
+
try:
|
| 193 |
+
# --- Stage 1: LLM Rewrite ---
|
| 194 |
+
if skip_rewrite:
|
| 195 |
+
phrases = [p.strip() for p in caption.split(",") if p.strip()]
|
| 196 |
+
if len(phrases) <= 1:
|
| 197 |
+
phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()]
|
| 198 |
+
result.rewrite_phrases = phrases
|
| 199 |
+
result.stage1_time = 0.0
|
| 200 |
+
else:
|
| 201 |
+
t0 = time.time()
|
| 202 |
+
rewritten = llm_rewrite_prompt(caption, log)
|
| 203 |
+
result.stage1_time = time.time() - t0
|
| 204 |
+
if rewritten:
|
| 205 |
+
result.rewrite_phrases = [p.strip() for p in rewritten.split(",") if p.strip()]
|
| 206 |
+
else:
|
| 207 |
+
result.rewrite_phrases = [p.strip() for p in caption.split(",") if p.strip()]
|
| 208 |
+
if len(result.rewrite_phrases) <= 1:
|
| 209 |
+
result.rewrite_phrases = [p.strip() for p in caption.replace(".", ",").split(",") if p.strip()]
|
| 210 |
+
|
| 211 |
+
log(f"Phrases ({len(result.rewrite_phrases)}): {result.rewrite_phrases[:5]}")
|
| 212 |
+
|
| 213 |
+
# --- Stage 2: Retrieval ---
|
| 214 |
+
t0 = time.time()
|
| 215 |
+
retrieval_result = psq_candidates_from_rewrite_phrases(
|
| 216 |
+
rewrite_phrases=result.rewrite_phrases,
|
| 217 |
+
allow_nsfw_tags=allow_nsfw,
|
| 218 |
+
global_k=300,
|
| 219 |
+
verbose=False,
|
| 220 |
+
)
|
| 221 |
+
result.stage2_time = time.time() - t0
|
| 222 |
+
|
| 223 |
+
if isinstance(retrieval_result, tuple):
|
| 224 |
+
candidates, _ = retrieval_result
|
| 225 |
+
else:
|
| 226 |
+
candidates = retrieval_result
|
| 227 |
+
|
| 228 |
+
result.retrieved_tags = {c.tag for c in candidates}
|
| 229 |
+
if gt_tags:
|
| 230 |
+
result.retrieval_recall = len(result.retrieved_tags & gt_tags) / len(gt_tags)
|
| 231 |
+
|
| 232 |
+
log(f"Retrieved {len(candidates)} candidates, recall={result.retrieval_recall:.3f}")
|
| 233 |
+
|
| 234 |
+
# --- Stage 3: LLM Selection ---
|
| 235 |
+
t0 = time.time()
|
| 236 |
+
picked_indices = llm_select_indices(
|
| 237 |
+
query_text=caption,
|
| 238 |
+
candidates=candidates,
|
| 239 |
+
max_pick=0,
|
| 240 |
+
log=log,
|
| 241 |
+
mode=mode,
|
| 242 |
+
chunk_size=chunk_size,
|
| 243 |
+
per_phrase_k=per_phrase_k,
|
| 244 |
+
temperature=temperature,
|
| 245 |
+
max_tokens=max_tokens,
|
| 246 |
+
)
|
| 247 |
+
result.stage3_time = time.time() - t0
|
| 248 |
+
|
| 249 |
+
result.selected_tags = {candidates[idx].tag for idx in picked_indices} if picked_indices else set()
|
| 250 |
+
|
| 251 |
+
# Overall selection metrics
|
| 252 |
+
p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
|
| 253 |
+
result.selection_precision = p
|
| 254 |
+
result.selection_recall = r
|
| 255 |
+
result.selection_f1 = f1
|
| 256 |
+
|
| 257 |
+
# Split ground-truth and selected tags by type
|
| 258 |
+
gt_char, gt_gen = _classify_tags(gt_tags, get_tag_type_name)
|
| 259 |
+
sel_char, sel_gen = _classify_tags(result.selected_tags, get_tag_type_name)
|
| 260 |
+
ret_char, _ = _classify_tags(result.retrieved_tags, get_tag_type_name)
|
| 261 |
+
|
| 262 |
+
result.gt_character_tags = gt_char
|
| 263 |
+
result.selected_character_tags = sel_char
|
| 264 |
+
result.retrieved_character_tags = ret_char
|
| 265 |
+
result.gt_general_tags = gt_gen
|
| 266 |
+
result.selected_general_tags = sel_gen
|
| 267 |
+
|
| 268 |
+
# Character-specific metrics
|
| 269 |
+
if gt_char:
|
| 270 |
+
result.char_retrieval_recall = len(ret_char & gt_char) / len(gt_char)
|
| 271 |
+
cp, cr, cf1 = _compute_metrics(sel_char, gt_char)
|
| 272 |
+
result.char_precision = cp
|
| 273 |
+
result.char_recall = cr
|
| 274 |
+
result.char_f1 = cf1
|
| 275 |
+
|
| 276 |
+
# General-tag metrics
|
| 277 |
+
gp, gr, gf1 = _compute_metrics(sel_gen, gt_gen)
|
| 278 |
+
result.general_precision = gp
|
| 279 |
+
result.general_recall = gr
|
| 280 |
+
result.general_f1 = gf1
|
| 281 |
+
|
| 282 |
+
# Per-sample output line
|
| 283 |
+
char_info = ""
|
| 284 |
+
if gt_char:
|
| 285 |
+
char_info = f" char[gt={len(gt_char)} sel={len(sel_char)} P={cp:.2f} R={cr:.2f}]"
|
| 286 |
+
with print_lock:
|
| 287 |
+
print(
|
| 288 |
+
f" [{index+1}] retrieval_recall={result.retrieval_recall:.3f} "
|
| 289 |
+
f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
|
| 290 |
+
f"selected={len(result.selected_tags)}{char_info} "
|
| 291 |
+
f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
except Exception as e:
|
| 295 |
+
result.error = str(e)
|
| 296 |
+
with print_lock:
|
| 297 |
+
print(f" [{index+1}] ERROR: {e}")
|
| 298 |
+
|
| 299 |
+
return result
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def _prewarm_retrieval_assets() -> None:
|
| 303 |
+
"""Force-load all lazy retrieval assets so threads don't race on init."""
|
| 304 |
+
from psq_rag.retrieval.state import (
|
| 305 |
+
get_tfidf_components,
|
| 306 |
+
get_tag2aliases,
|
| 307 |
+
get_tag_type_name,
|
| 308 |
+
)
|
| 309 |
+
print("Pre-warming retrieval assets (TF-IDF, FastText, HNSW, aliases)...")
|
| 310 |
+
t0 = time.time()
|
| 311 |
+
get_tfidf_components() # loads joblib, HNSW indexes, FastText model
|
| 312 |
+
get_tag2aliases() # loads CSV alias dict
|
| 313 |
+
get_tag_type_name("_warmup_") # ensures tag type dict is built
|
| 314 |
+
print(f" Assets loaded in {time.time() - t0:.1f}s")
|
| 315 |
+
|
| 316 |
+
|
| 317 |
def run_eval(
|
| 318 |
n_samples: int = 20,
|
| 319 |
caption_field: str = "caption_cogvlm",
|
|
|
|
| 327 |
verbose: bool = False,
|
| 328 |
shuffle: bool = True,
|
| 329 |
seed: int = 42,
|
| 330 |
+
workers: int = 1,
|
| 331 |
) -> List[SampleResult]:
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
# Load eval samples
|
| 334 |
if not EVAL_DATA_PATH.is_file():
|
| 335 |
print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
|
|
|
|
| 359 |
|
| 360 |
print(f"Loaded {len(samples)}/{len(all_samples)} samples (caption_field={caption_field})")
|
| 361 |
print(f"shuffle={shuffle}, seed={seed}, skip_rewrite={skip_rewrite}, allow_nsfw={allow_nsfw}, mode={mode}")
|
| 362 |
+
print(f"workers={workers}")
|
| 363 |
print()
|
| 364 |
|
| 365 |
+
# Pre-warm shared retrieval assets before spawning threads
|
| 366 |
+
_prewarm_retrieval_assets()
|
| 367 |
+
print()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
+
print_lock = threading.Lock()
|
| 370 |
+
total = len(samples)
|
| 371 |
+
|
| 372 |
+
if workers <= 1:
|
| 373 |
+
# Sequential mode (original behavior)
|
| 374 |
+
results: List[SampleResult] = []
|
| 375 |
+
for i, sample in enumerate(samples):
|
| 376 |
+
result = _process_one_sample(
|
| 377 |
+
sample, i, total,
|
| 378 |
+
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 379 |
+
per_phrase_k, temperature, max_tokens, verbose,
|
| 380 |
+
print_lock,
|
| 381 |
)
|
| 382 |
+
results.append(result)
|
| 383 |
+
else:
|
| 384 |
+
# Parallel mode
|
| 385 |
+
print(f"Processing {total} samples with {workers} parallel workers...")
|
| 386 |
+
print()
|
| 387 |
+
# Submit all samples; use index to preserve original ordering
|
| 388 |
+
results_by_index: Dict[int, SampleResult] = {}
|
| 389 |
+
with ThreadPoolExecutor(max_workers=workers) as executor:
|
| 390 |
+
futures = {
|
| 391 |
+
executor.submit(
|
| 392 |
+
_process_one_sample,
|
| 393 |
+
sample, i, total,
|
| 394 |
+
skip_rewrite, allow_nsfw, mode, chunk_size,
|
| 395 |
+
per_phrase_k, temperature, max_tokens, verbose,
|
| 396 |
+
print_lock,
|
| 397 |
+
): i
|
| 398 |
+
for i, sample in enumerate(samples)
|
| 399 |
+
}
|
| 400 |
+
for future in as_completed(futures):
|
| 401 |
+
idx = futures[future]
|
| 402 |
+
try:
|
| 403 |
+
results_by_index[idx] = future.result()
|
| 404 |
+
except Exception as e:
|
| 405 |
+
# Should not happen since _process_one_sample catches exceptions,
|
| 406 |
+
# but guard against unexpected errors
|
| 407 |
+
with print_lock:
|
| 408 |
+
print(f" [{idx+1}] WORKER ERROR: {e}")
|
| 409 |
+
result = SampleResult(
|
| 410 |
+
sample_id=samples[idx]["id"],
|
| 411 |
+
caption=samples[idx]["caption"][:120],
|
| 412 |
+
ground_truth_tags=samples[idx]["gt_tags"],
|
| 413 |
+
error=f"Worker error: {e}",
|
| 414 |
+
)
|
| 415 |
+
results_by_index[idx] = result
|
| 416 |
+
|
| 417 |
+
# Reassemble in original order
|
| 418 |
+
results = [results_by_index[i] for i in range(total)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
|
| 420 |
return results
|
| 421 |
|
|
|
|
| 598 |
help="Use samples in file order (first N)")
|
| 599 |
ap.add_argument("--seed", type=int, default=42,
|
| 600 |
help="Random seed for shuffle (default: 42)")
|
| 601 |
+
ap.add_argument("--workers", "-w", type=int, default=4,
|
| 602 |
+
help="Number of parallel workers (default: 4, use 1 for sequential)")
|
| 603 |
|
| 604 |
args = ap.parse_args(list(argv) if argv is not None else None)
|
| 605 |
|
|
|
|
| 616 |
verbose=args.verbose,
|
| 617 |
shuffle=args.shuffle,
|
| 618 |
seed=args.seed,
|
| 619 |
+
workers=args.workers,
|
| 620 |
)
|
| 621 |
|
| 622 |
print_summary(results)
|
|
|
|
| 646 |
"temperature": args.temperature,
|
| 647 |
"shuffle": args.shuffle,
|
| 648 |
"seed": args.seed,
|
| 649 |
+
"workers": args.workers,
|
| 650 |
"n_errors": sum(1 for r in results if r.error),
|
| 651 |
}
|
| 652 |
|