Claude committed on
Commit
eeada1d
·
1 Parent(s): de8b5a3

Add tag implication expansion (fox→canine→canid→mammal)

Browse files

Walks the e621 implication graph upward from each selected tag to add
ancestor taxonomy tags that were structurally unreachable by retrieval.

- state.py: get_tag_implications() loads+caches the directed graph,
expand_tags_via_implications() BFS-walks ancestors from a tag set
- app.py: expansion runs between Stage 3 and Stage 4 (always on)
- eval_pipeline.py: --expand-implications flag for controlled eval

https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG

Files changed (3) hide show
  1. app.py +10 -0
  2. psq_rag/retrieval/state.py +54 -0
  3. scripts/eval_pipeline.py +27 -4
app.py CHANGED
@@ -9,6 +9,7 @@ from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
9
  from psq_rag.llm.rewrite import llm_rewrite_prompt
10
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
11
  from psq_rag.llm.select import llm_select_indices
 
12
 
13
 
14
  def _split_prompt_commas(s: str) -> List[str]:
@@ -223,6 +224,15 @@ def rag_pipeline_ui(user_prompt: str):
223
 
224
  selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []
225
 
 
 
 
 
 
 
 
 
 
226
  log("Step 4: Compose final prompt")
227
  final_prompt = compose_final_prompt(rewritten, selected_tags)
228
 
 
9
  from psq_rag.llm.rewrite import llm_rewrite_prompt
10
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
11
  from psq_rag.llm.select import llm_select_indices
12
+ from psq_rag.retrieval.state import expand_tags_via_implications
13
 
14
 
15
  def _split_prompt_commas(s: str) -> List[str]:
 
224
 
225
  selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []
226
 
227
+ log("Step 3b: Expand via tag implications")
228
+ tag_set = set(selected_tags)
229
+ expanded, implied_only = expand_tags_via_implications(tag_set)
230
+ if implied_only:
231
+ selected_tags.extend(sorted(implied_only))
232
+ log(f" Added {len(implied_only)} implied tags: {', '.join(sorted(implied_only))}")
233
+ else:
234
+ log(" No additional implied tags")
235
+
236
  log("Step 4: Compose final prompt")
237
  final_prompt = compose_final_prompt(rewritten, selected_tags)
238
 
psq_rag/retrieval/state.py CHANGED
@@ -22,6 +22,7 @@ HNSW_ART_PATH = pathlib.Path("tfidf_hnsw_artists.bin")
22
  HNSW_TAG_PATH = pathlib.Path("tfidf_hnsw_tags.bin")
23
  FASTTEXT_MODEL_PATH = pathlib.Path("e621FastTextModel010Replacement_small.bin")
24
  TAG_ALIASES_PATH = pathlib.Path("fluffyrock_3m.csv")
 
25
 
26
  _tfidf_components: Optional[Dict[str, Any]] = None
27
  _nsfw_tags: Optional[Set[str]] = None
@@ -32,6 +33,7 @@ _tfidf_tag_vectors: Optional[Dict[str, Any]] = None
32
  _alias_to_tags: Optional[Dict[str, List[str]]] = None
33
  _tag_to_aliases: Optional[Dict[str, List[str]]] = None
34
  _tag_type_id: Optional[Dict[str, int]] = None
 
35
 
36
 
37
  _hnsw_tag_index: Optional["hnswlib.Index"] = None
@@ -273,6 +275,58 @@ def get_tag2aliases() -> Dict[str, List[str]]:
273
  return _tag_to_aliases
274
 
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  def get_tfidf_tag_vectors() -> Dict[str, Any]:
277
  global _tfidf_tag_vectors
278
  if _tfidf_tag_vectors is not None:
 
22
  HNSW_TAG_PATH = pathlib.Path("tfidf_hnsw_tags.bin")
23
  FASTTEXT_MODEL_PATH = pathlib.Path("e621FastTextModel010Replacement_small.bin")
24
  TAG_ALIASES_PATH = pathlib.Path("fluffyrock_3m.csv")
25
+ TAG_IMPLICATIONS_PATH = pathlib.Path("tag_implications-2023-07-20.csv")
26
 
27
  _tfidf_components: Optional[Dict[str, Any]] = None
28
  _nsfw_tags: Optional[Set[str]] = None
 
33
  _alias_to_tags: Optional[Dict[str, List[str]]] = None
34
  _tag_to_aliases: Optional[Dict[str, List[str]]] = None
35
  _tag_type_id: Optional[Dict[str, int]] = None
36
+ _tag_implications: Optional[Dict[str, List[str]]] = None
37
 
38
 
39
  _hnsw_tag_index: Optional["hnswlib.Index"] = None
 
275
  return _tag_to_aliases
276
 
277
 
278
+ def get_tag_implications() -> Dict[str, List[str]]:
279
+ """Return antecedent_tag -> [consequent_tags] from the implications CSV.
280
+
281
+ Only active implications where both tags exist in the tag database are kept.
282
+ """
283
+ global _tag_implications
284
+ if _tag_implications is not None:
285
+ return _tag_implications
286
+
287
+ if not TAG_IMPLICATIONS_PATH.is_file():
288
+ logging.warning("Tag implications CSV not found: %s", TAG_IMPLICATIONS_PATH)
289
+ _tag_implications = {}
290
+ return _tag_implications
291
+
292
+ known_tags = set(get_tag_type_ids().keys())
293
+ impl: Dict[str, List[str]] = {}
294
+ with TAG_IMPLICATIONS_PATH.open("r", newline="", encoding="utf-8") as csvfile:
295
+ reader = csv.reader(csvfile)
296
+ next(reader, None) # skip header
297
+ for row in reader:
298
+ if len(row) < 5 or row[4] != "active":
299
+ continue
300
+ antecedent = clean_tag(row[1])
301
+ consequent = clean_tag(row[2])
302
+ if antecedent in known_tags and consequent in known_tags:
303
+ impl.setdefault(antecedent, []).append(consequent)
304
+
305
+ _tag_implications = impl
306
+ logging.info("Loaded %d tag implications", sum(len(v) for v in impl.values()))
307
+ return _tag_implications
308
+
309
+
310
+ def expand_tags_via_implications(tags: Set[str]) -> Tuple[Set[str], Set[str]]:
311
+ """Walk the implication graph upward from each tag, collecting ancestors.
312
+
313
+ Returns (all_tags, implied_only) where:
314
+ - all_tags = original tags + implied ancestors
315
+ - implied_only = tags that were added (not in the original set)
316
+ """
317
+ impl = get_tag_implications()
318
+ expanded = set(tags)
319
+ queue = list(tags)
320
+ while queue:
321
+ tag = queue.pop()
322
+ for parent in impl.get(tag, ()):
323
+ if parent not in expanded:
324
+ expanded.add(parent)
325
+ queue.append(parent)
326
+ implied_only = expanded - tags
327
+ return expanded, implied_only
328
+
329
+
330
  def get_tfidf_tag_vectors() -> Dict[str, Any]:
331
  global _tfidf_tag_vectors
332
  if _tfidf_tag_vectors is not None:
scripts/eval_pipeline.py CHANGED
@@ -133,6 +133,8 @@ class SampleResult:
133
  over_selection_ratio: float = 0.0 # |selected| / |gt|
134
  # Why distribution (from Stage 3 LLM)
135
  why_counts: Dict[str, int] = field(default_factory=dict)
 
 
136
  # Timing
137
  stage1_time: float = 0.0
138
  stage2_time: float = 0.0
@@ -171,12 +173,13 @@ def _process_one_sample(
171
  verbose: bool,
172
  print_lock: threading.Lock,
173
  min_why: Optional[str] = None,
 
174
  ) -> SampleResult:
175
  """Process a single eval sample through the full pipeline. Thread-safe."""
176
  from psq_rag.llm.rewrite import llm_rewrite_prompt
177
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
178
  from psq_rag.llm.select import llm_select_indices
179
- from psq_rag.retrieval.state import get_tag_type_name
180
 
181
  def log(msg: str) -> None:
182
  if verbose:
@@ -263,6 +266,13 @@ def _process_one_sample(
263
  why_counts[w] = why_counts.get(w, 0) + 1
264
  result.why_counts = why_counts
265
 
 
 
 
 
 
 
 
266
  # Overall selection metrics
267
  p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
268
  result.selection_precision = p
@@ -308,11 +318,12 @@ def _process_one_sample(
308
  char_info = ""
309
  if gt_char:
310
  char_info = f" char[gt={len(gt_char)} sel={len(sel_char)} P={cp:.2f} R={cr:.2f}]"
 
311
  with print_lock:
312
  print(
313
  f" [{index+1}] retrieval_recall={result.retrieval_recall:.3f} "
314
  f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
315
- f"selected={len(result.selected_tags)}{char_info} "
316
  f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
317
  )
318
 
@@ -330,12 +341,14 @@ def _prewarm_retrieval_assets() -> None:
330
  get_tfidf_components,
331
  get_tag2aliases,
332
  get_tag_type_name,
 
333
  )
334
  print("Pre-warming retrieval assets (TF-IDF, FastText, HNSW, aliases)...")
335
  t0 = time.time()
336
  get_tfidf_components() # loads joblib, HNSW indexes, FastText model
337
  get_tag2aliases() # loads CSV alias dict
338
  get_tag_type_name("_warmup_") # ensures tag type dict is built
 
339
  print(f" Assets loaded in {time.time() - t0:.1f}s")
340
 
341
 
@@ -354,6 +367,7 @@ def run_eval(
354
  seed: int = 42,
355
  workers: int = 1,
356
  min_why: Optional[str] = None,
 
357
  ) -> List[SampleResult]:
358
 
359
  # Load eval samples
@@ -403,7 +417,7 @@ def run_eval(
403
  sample, i, total,
404
  skip_rewrite, allow_nsfw, mode, chunk_size,
405
  per_phrase_k, temperature, max_tokens, verbose,
406
- print_lock, min_why,
407
  )
408
  results.append(result)
409
  else:
@@ -419,7 +433,7 @@ def run_eval(
419
  sample, i, total,
420
  skip_rewrite, allow_nsfw, mode, chunk_size,
421
  per_phrase_k, temperature, max_tokens, verbose,
422
- print_lock, min_why,
423
  ): i
424
  for i, sample in enumerate(samples)
425
  }
@@ -487,12 +501,16 @@ def print_summary(results: List[SampleResult]) -> None:
487
  if (r.retrieved_tags & r.ground_truth_tags)])
488
  avg_over_sel = _safe_avg([r.over_selection_ratio for r in valid])
489
 
 
 
490
  print()
491
  print("Stage 3 - Selection (ALL tags):")
492
  print(f" Avg precision: {avg_sel_precision:.4f}")
493
  print(f" Avg recall: {avg_sel_recall:.4f}")
494
  print(f" Avg F1: {avg_sel_f1:.4f}")
495
  print(f" Avg selected tags: {avg_selected:.1f}")
 
 
496
  print(f" Avg ground-truth tags:{avg_gt:.1f}")
497
  print()
498
  print("Diagnostic Metrics:")
@@ -653,6 +671,8 @@ def main(argv=None) -> int:
653
  ap.add_argument("--min-why", default=None,
654
  choices=["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"],
655
  help="Minimum 'why' confidence to keep (e.g. 'explicit' keeps only explicit matches)")
 
 
656
 
657
  args = ap.parse_args(list(argv) if argv is not None else None)
658
 
@@ -671,6 +691,7 @@ def main(argv=None) -> int:
671
  seed=args.seed,
672
  workers=args.workers,
673
  min_why=args.min_why,
 
674
  )
675
 
676
  print_summary(results)
@@ -702,6 +723,7 @@ def main(argv=None) -> int:
702
  "seed": args.seed,
703
  "workers": args.workers,
704
  "min_why": args.min_why,
 
705
  "n_errors": sum(1 for r in results if r.error),
706
  }
707
 
@@ -738,6 +760,7 @@ def main(argv=None) -> int:
738
  "selection_given_retrieval": round(r.selection_given_retrieval, 4),
739
  "over_selection_ratio": round(r.over_selection_ratio, 2),
740
  "why_counts": r.why_counts,
 
741
  # Timing
742
  "stage1_time": round(r.stage1_time, 3),
743
  "stage2_time": round(r.stage2_time, 3),
 
133
  over_selection_ratio: float = 0.0 # |selected| / |gt|
134
  # Why distribution (from Stage 3 LLM)
135
  why_counts: Dict[str, int] = field(default_factory=dict)
136
+ # Tag implications
137
+ implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
138
  # Timing
139
  stage1_time: float = 0.0
140
  stage2_time: float = 0.0
 
173
  verbose: bool,
174
  print_lock: threading.Lock,
175
  min_why: Optional[str] = None,
176
+ expand_implications: bool = False,
177
  ) -> SampleResult:
178
  """Process a single eval sample through the full pipeline. Thread-safe."""
179
  from psq_rag.llm.rewrite import llm_rewrite_prompt
180
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
181
  from psq_rag.llm.select import llm_select_indices
182
+ from psq_rag.retrieval.state import get_tag_type_name, expand_tags_via_implications
183
 
184
  def log(msg: str) -> None:
185
  if verbose:
 
266
  why_counts[w] = why_counts.get(w, 0) + 1
267
  result.why_counts = why_counts
268
 
269
+ # Tag implication expansion (post-Stage 3)
270
+ if expand_implications and result.selected_tags:
271
+ expanded, implied_only = expand_tags_via_implications(result.selected_tags)
272
+ result.implied_tags = implied_only
273
+ result.selected_tags = expanded
274
+ log(f"Implications: +{len(implied_only)} tags")
275
+
276
  # Overall selection metrics
277
  p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
278
  result.selection_precision = p
 
318
  char_info = ""
319
  if gt_char:
320
  char_info = f" char[gt={len(gt_char)} sel={len(sel_char)} P={cp:.2f} R={cr:.2f}]"
321
+ impl_info = f" (+{len(result.implied_tags)} implied)" if result.implied_tags else ""
322
  with print_lock:
323
  print(
324
  f" [{index+1}] retrieval_recall={result.retrieval_recall:.3f} "
325
  f"sel_P={p:.3f} sel_R={r:.3f} sel_F1={f1:.3f} "
326
+ f"selected={len(result.selected_tags)}{impl_info}{char_info} "
327
  f"t1={result.stage1_time:.1f}s t2={result.stage2_time:.1f}s t3={result.stage3_time:.1f}s"
328
  )
329
 
 
341
  get_tfidf_components,
342
  get_tag2aliases,
343
  get_tag_type_name,
344
+ get_tag_implications,
345
  )
346
  print("Pre-warming retrieval assets (TF-IDF, FastText, HNSW, aliases)...")
347
  t0 = time.time()
348
  get_tfidf_components() # loads joblib, HNSW indexes, FastText model
349
  get_tag2aliases() # loads CSV alias dict
350
  get_tag_type_name("_warmup_") # ensures tag type dict is built
351
+ get_tag_implications() # loads implication graph
352
  print(f" Assets loaded in {time.time() - t0:.1f}s")
353
 
354
 
 
367
  seed: int = 42,
368
  workers: int = 1,
369
  min_why: Optional[str] = None,
370
+ expand_implications: bool = False,
371
  ) -> List[SampleResult]:
372
 
373
  # Load eval samples
 
417
  sample, i, total,
418
  skip_rewrite, allow_nsfw, mode, chunk_size,
419
  per_phrase_k, temperature, max_tokens, verbose,
420
+ print_lock, min_why, expand_implications,
421
  )
422
  results.append(result)
423
  else:
 
433
  sample, i, total,
434
  skip_rewrite, allow_nsfw, mode, chunk_size,
435
  per_phrase_k, temperature, max_tokens, verbose,
436
+ print_lock, min_why, expand_implications,
437
  ): i
438
  for i, sample in enumerate(samples)
439
  }
 
501
  if (r.retrieved_tags & r.ground_truth_tags)])
502
  avg_over_sel = _safe_avg([r.over_selection_ratio for r in valid])
503
 
504
+ avg_implied = sum(len(r.implied_tags) for r in valid) / n
505
+
506
  print()
507
  print("Stage 3 - Selection (ALL tags):")
508
  print(f" Avg precision: {avg_sel_precision:.4f}")
509
  print(f" Avg recall: {avg_sel_recall:.4f}")
510
  print(f" Avg F1: {avg_sel_f1:.4f}")
511
  print(f" Avg selected tags: {avg_selected:.1f}")
512
+ if avg_implied > 0:
513
+ print(f" Avg implied tags: {avg_implied:.1f} (added via tag implications)")
514
  print(f" Avg ground-truth tags:{avg_gt:.1f}")
515
  print()
516
  print("Diagnostic Metrics:")
 
671
  ap.add_argument("--min-why", default=None,
672
  choices=["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"],
673
  help="Minimum 'why' confidence to keep (e.g. 'explicit' keeps only explicit matches)")
674
+ ap.add_argument("--expand-implications", action="store_true", default=False,
675
+ help="Expand selected tags via tag implication chains (e.g. fox→canine→canid→mammal)")
676
 
677
  args = ap.parse_args(list(argv) if argv is not None else None)
678
 
 
691
  seed=args.seed,
692
  workers=args.workers,
693
  min_why=args.min_why,
694
+ expand_implications=args.expand_implications,
695
  )
696
 
697
  print_summary(results)
 
723
  "seed": args.seed,
724
  "workers": args.workers,
725
  "min_why": args.min_why,
726
+ "expand_implications": args.expand_implications,
727
  "n_errors": sum(1 for r in results if r.error),
728
  }
729
 
 
760
  "selection_given_retrieval": round(r.selection_given_retrieval, 4),
761
  "over_selection_ratio": round(r.over_selection_ratio, 2),
762
  "why_counts": r.why_counts,
763
+ "implied_tags": sorted(r.implied_tags),
764
  # Timing
765
  "stage1_time": round(r.stage1_time, 3),
766
  "stage2_time": round(r.stage2_time, 3),