Spaces:
Running
Running
Claude committed on
Commit ·
14e5c38
1
Parent(s): 6fc4b56
Normalize GT annotations: expand implications, exclude non-evaluable tags
Browse files
Addresses annotation inconsistency where 30% of GT samples were missing
implied taxonomy tags (e.g. fox present but canid/mammal absent).
- preprocess_eval_data.py: expands GT through implication graph, writes
_expanded.jsonl with tags_ground_truth_expanded field
- eval_pipeline.py: uses expanded GT, strips _EVAL_EXCLUDED_TAGS
(invalid_*, hi_res, structural backgrounds) from both sides,
reports leaf-only metrics alongside expanded metrics
- state.py: adds get_leaf_tags() to strip implied ancestors from a tag set
https://claude.ai/code/session_019PY5TEXTWGtToUbowunSRG
data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_expanded.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
psq_rag/retrieval/state.py
CHANGED
|
@@ -327,6 +327,29 @@ def expand_tags_via_implications(tags: Set[str]) -> Tuple[Set[str], Set[str]]:
|
|
| 327 |
return expanded, implied_only
|
| 328 |
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
def get_tfidf_tag_vectors() -> Dict[str, Any]:
|
| 331 |
global _tfidf_tag_vectors
|
| 332 |
if _tfidf_tag_vectors is not None:
|
|
|
|
| 327 |
return expanded, implied_only
|
| 328 |
|
| 329 |
|
| 330 |
+
def get_leaf_tags(tags: Set[str]) -> Set[str]:
    """Return only leaf tags — those not implied by any other tag in the set.

    For example, given {fox, canine, canid, mammal}, returns {fox} because
    canine/canid/mammal are all reachable from fox via implications.
    """
    impl = get_tag_implications()
    # Every tag reachable from some other member of `tags` via the
    # implication graph is an ancestor, i.e. not a leaf.
    non_leaves: Set[str] = set()
    for start in tags:
        seen: Set[str] = set()
        stack = [start]
        # Depth-first walk of this tag's implication ancestors.
        while stack:
            current = stack.pop()
            for ancestor in impl.get(current, ()):
                if ancestor in seen:
                    continue
                seen.add(ancestor)
                # Only ancestors that actually appear in the input set
                # disqualify themselves from being leaves.
                if ancestor in tags:
                    non_leaves.add(ancestor)
                stack.append(ancestor)
    return tags - non_leaves
|
| 351 |
+
|
| 352 |
+
|
| 353 |
def get_tfidf_tag_vectors() -> Dict[str, Any]:
|
| 354 |
global _tfidf_tag_vectors
|
| 355 |
if _tfidf_tag_vectors is not None:
|
scripts/eval_pipeline.py
CHANGED
|
@@ -57,13 +57,29 @@ if str(_REPO_ROOT) not in sys.path:
|
|
| 57 |
sys.path.insert(0, str(_REPO_ROOT))
|
| 58 |
os.chdir(_REPO_ROOT)
|
| 59 |
|
| 60 |
-
EVAL_DATA_PATH = _REPO_ROOT / "data" / "eval_samples" / "
|
|
|
|
| 61 |
|
| 62 |
# Character tag types that go through the alias filter pipeline
|
| 63 |
_CHARACTER_TYPES = {"character"}
|
| 64 |
# Copyright tags are filtered out entirely
|
| 65 |
_COPYRIGHT_TYPES = {"copyright"}
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
def _classify_tags(tags: Set[str], get_type_fn) -> Tuple[Set[str], Set[str]]:
|
| 69 |
"""Split tags into (character_tags, general_tags).
|
|
@@ -135,6 +151,12 @@ class SampleResult:
|
|
| 135 |
why_counts: Dict[str, int] = field(default_factory=dict)
|
| 136 |
# Tag implications
|
| 137 |
implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
# Timing
|
| 139 |
stage1_time: float = 0.0
|
| 140 |
stage2_time: float = 0.0
|
|
@@ -179,7 +201,7 @@ def _process_one_sample(
|
|
| 179 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 180 |
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
|
| 181 |
from psq_rag.llm.select import llm_select_indices
|
| 182 |
-
from psq_rag.retrieval.state import get_tag_type_name, expand_tags_via_implications
|
| 183 |
|
| 184 |
def log(msg: str) -> None:
|
| 185 |
if verbose:
|
|
@@ -273,13 +295,27 @@ def _process_one_sample(
|
|
| 273 |
result.selected_tags = expanded
|
| 274 |
log(f"Implications: +{len(implied_only)} tags")
|
| 275 |
|
| 276 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
|
| 278 |
result.selection_precision = p
|
| 279 |
result.selection_recall = r
|
| 280 |
result.selection_f1 = f1
|
| 281 |
|
| 282 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
retrieved_and_gt = result.retrieved_tags & gt_tags
|
| 284 |
selected_and_gt = result.selected_tags & gt_tags
|
| 285 |
if result.retrieved_tags:
|
|
@@ -370,26 +406,41 @@ def run_eval(
|
|
| 370 |
expand_implications: bool = False,
|
| 371 |
) -> List[SampleResult]:
|
| 372 |
|
| 373 |
-
# Load eval samples
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
|
| 378 |
all_samples = []
|
| 379 |
-
|
|
|
|
| 380 |
for line in f:
|
| 381 |
row = json.loads(line)
|
| 382 |
caption = row.get(caption_field, "")
|
| 383 |
if not caption or not caption.strip():
|
| 384 |
continue
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
if not gt_tags:
|
| 387 |
continue
|
|
|
|
|
|
|
| 388 |
all_samples.append({
|
| 389 |
"id": row.get("id", row.get("row_id", len(all_samples))),
|
| 390 |
"caption": caption.strip(),
|
| 391 |
"gt_tags": gt_tags,
|
| 392 |
})
|
|
|
|
|
|
|
| 393 |
|
| 394 |
if shuffle:
|
| 395 |
rng = random.Random(seed)
|
|
@@ -512,6 +563,21 @@ def print_summary(results: List[SampleResult]) -> None:
|
|
| 512 |
if avg_implied > 0:
|
| 513 |
print(f" Avg implied tags: {avg_implied:.1f} (added via tag implications)")
|
| 514 |
print(f" Avg ground-truth tags:{avg_gt:.1f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
print()
|
| 516 |
print("Diagnostic Metrics:")
|
| 517 |
print(f" Retrieval precision: {avg_retrieval_precision:.4f} (|ret∩gt|/|ret|, noise level fed to Stage 3)")
|
|
@@ -761,6 +827,12 @@ def main(argv=None) -> int:
|
|
| 761 |
"over_selection_ratio": round(r.over_selection_ratio, 2),
|
| 762 |
"why_counts": r.why_counts,
|
| 763 |
"implied_tags": sorted(r.implied_tags),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
# Timing
|
| 765 |
"stage1_time": round(r.stage1_time, 3),
|
| 766 |
"stage2_time": round(r.stage2_time, 3),
|
|
|
|
| 57 |
sys.path.insert(0, str(_REPO_ROOT))
|
| 58 |
os.chdir(_REPO_ROOT)
|
| 59 |
|
| 60 |
+
# Preferred eval data: GT pre-expanded through the tag-implication graph.
EVAL_DATA_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000_expanded.jsonl"
# Fallback: the raw (un-expanded) sample file the preprocessor reads.
EVAL_DATA_PATH_RAW = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"

# Character tag types that go through the alias filter pipeline
_CHARACTER_TYPES = {"character"}
# Copyright tags are filtered out entirely
_COPYRIGHT_TYPES = {"copyright"}

# Tags excluded from evaluation metrics but NOT removed from the pipeline.
# These are tags that either: can't be inferred from a caption (resolution,
# art medium), describe structural properties better handled outside the
# retrieval pipeline (backgrounds), or are annotation artifacts.
_EVAL_EXCLUDED_TAGS = frozenset({
    # Annotation artifacts
    "invalid_tag", "invalid_background",
    # Resolution / file meta — not inferrable from caption
    "hi_res", "absurd_res", "low_res", "superabsurd_res",
    # Structural background tags — better recommended independently
    "simple_background", "abstract_background", "detailed_background",
    "gradient_background", "blurred_background", "textured_background",
    "transparent_background", "white_background",
})
|
| 82 |
+
|
| 83 |
|
| 84 |
def _classify_tags(tags: Set[str], get_type_fn) -> Tuple[Set[str], Set[str]]:
|
| 85 |
"""Split tags into (character_tags, general_tags).
|
|
|
|
| 151 |
why_counts: Dict[str, int] = field(default_factory=dict)
|
| 152 |
# Tag implications
|
| 153 |
implied_tags: Set[str] = field(default_factory=set) # tags added via implications (not LLM-selected)
|
| 154 |
+
# Leaf-only metrics (strips implied ancestors from both sides)
|
| 155 |
+
leaf_precision: float = 0.0
|
| 156 |
+
leaf_recall: float = 0.0
|
| 157 |
+
leaf_f1: float = 0.0
|
| 158 |
+
leaf_selected_count: int = 0
|
| 159 |
+
leaf_gt_count: int = 0
|
| 160 |
# Timing
|
| 161 |
stage1_time: float = 0.0
|
| 162 |
stage2_time: float = 0.0
|
|
|
|
| 201 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 202 |
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
|
| 203 |
from psq_rag.llm.select import llm_select_indices
|
| 204 |
+
from psq_rag.retrieval.state import get_tag_type_name, expand_tags_via_implications, get_leaf_tags
|
| 205 |
|
| 206 |
def log(msg: str) -> None:
|
| 207 |
if verbose:
|
|
|
|
| 295 |
result.selected_tags = expanded
|
| 296 |
log(f"Implications: +{len(implied_only)} tags")
|
| 297 |
|
| 298 |
+
# Remove eval-excluded tags from predictions before scoring
|
| 299 |
+
result.selected_tags -= _EVAL_EXCLUDED_TAGS
|
| 300 |
+
result.retrieved_tags -= _EVAL_EXCLUDED_TAGS
|
| 301 |
+
|
| 302 |
+
# Overall selection metrics (expanded — both sides have full implication chains)
|
| 303 |
p, r, f1 = _compute_metrics(result.selected_tags, gt_tags)
|
| 304 |
result.selection_precision = p
|
| 305 |
result.selection_recall = r
|
| 306 |
result.selection_f1 = f1
|
| 307 |
|
| 308 |
+
# Leaf-only metrics (strips implied ancestors from both sides)
|
| 309 |
+
leaf_sel = get_leaf_tags(result.selected_tags)
|
| 310 |
+
leaf_gt = get_leaf_tags(gt_tags)
|
| 311 |
+
lp, lr, lf1 = _compute_metrics(leaf_sel, leaf_gt)
|
| 312 |
+
result.leaf_precision = lp
|
| 313 |
+
result.leaf_recall = lr
|
| 314 |
+
result.leaf_f1 = lf1
|
| 315 |
+
result.leaf_selected_count = len(leaf_sel)
|
| 316 |
+
result.leaf_gt_count = len(leaf_gt)
|
| 317 |
+
|
| 318 |
+
# Diagnostic metrics
|
| 319 |
retrieved_and_gt = result.retrieved_tags & gt_tags
|
| 320 |
selected_and_gt = result.selected_tags & gt_tags
|
| 321 |
if result.retrieved_tags:
|
|
|
|
| 406 |
expand_implications: bool = False,
|
| 407 |
) -> List[SampleResult]:
|
| 408 |
|
| 409 |
+
# Load eval samples — prefer expanded file, fall back to raw
|
| 410 |
+
eval_path = EVAL_DATA_PATH
|
| 411 |
+
if not eval_path.is_file():
|
| 412 |
+
eval_path = EVAL_DATA_PATH_RAW
|
| 413 |
+
if not eval_path.is_file():
|
| 414 |
+
print(f"ERROR: Eval data not found: {EVAL_DATA_PATH}")
|
| 415 |
+
sys.exit(1)
|
| 416 |
+
print(f"WARNING: Expanded eval data not found, falling back to raw: {eval_path}")
|
| 417 |
+
print(" Run: python scripts/preprocess_eval_data.py")
|
| 418 |
|
| 419 |
all_samples = []
|
| 420 |
+
using_expanded = False
|
| 421 |
+
with eval_path.open("r", encoding="utf-8") as f:
|
| 422 |
for line in f:
|
| 423 |
row = json.loads(line)
|
| 424 |
caption = row.get(caption_field, "")
|
| 425 |
if not caption or not caption.strip():
|
| 426 |
continue
|
| 427 |
+
# Prefer pre-expanded GT; fall back to flattening categorized
|
| 428 |
+
if "tags_ground_truth_expanded" in row:
|
| 429 |
+
gt_tags = set(row["tags_ground_truth_expanded"])
|
| 430 |
+
using_expanded = True
|
| 431 |
+
else:
|
| 432 |
+
gt_tags = _flatten_ground_truth_tags(row.get("tags_ground_truth_categorized", ""))
|
| 433 |
if not gt_tags:
|
| 434 |
continue
|
| 435 |
+
# Remove eval-excluded tags from GT
|
| 436 |
+
gt_tags -= _EVAL_EXCLUDED_TAGS
|
| 437 |
all_samples.append({
|
| 438 |
"id": row.get("id", row.get("row_id", len(all_samples))),
|
| 439 |
"caption": caption.strip(),
|
| 440 |
"gt_tags": gt_tags,
|
| 441 |
})
|
| 442 |
+
if using_expanded:
|
| 443 |
+
print("Using implication-expanded ground truth")
|
| 444 |
|
| 445 |
if shuffle:
|
| 446 |
rng = random.Random(seed)
|
|
|
|
| 563 |
if avg_implied > 0:
|
| 564 |
print(f" Avg implied tags: {avg_implied:.1f} (added via tag implications)")
|
| 565 |
print(f" Avg ground-truth tags:{avg_gt:.1f}")
|
| 566 |
+
|
| 567 |
+
# Leaf-only metrics
|
| 568 |
+
avg_leaf_p = _safe_avg([r.leaf_precision for r in valid])
|
| 569 |
+
avg_leaf_r = _safe_avg([r.leaf_recall for r in valid])
|
| 570 |
+
avg_leaf_f1 = _safe_avg([r.leaf_f1 for r in valid])
|
| 571 |
+
avg_leaf_sel = _safe_avg([r.leaf_selected_count for r in valid])
|
| 572 |
+
avg_leaf_gt = _safe_avg([r.leaf_gt_count for r in valid])
|
| 573 |
+
print()
|
| 574 |
+
print("Stage 3 - Selection (LEAF tags only — implied ancestors stripped):")
|
| 575 |
+
print(f" Avg precision: {avg_leaf_p:.4f}")
|
| 576 |
+
print(f" Avg recall: {avg_leaf_r:.4f}")
|
| 577 |
+
print(f" Avg F1: {avg_leaf_f1:.4f}")
|
| 578 |
+
print(f" Avg leaf selected: {avg_leaf_sel:.1f}")
|
| 579 |
+
print(f" Avg leaf ground-truth:{avg_leaf_gt:.1f}")
|
| 580 |
+
|
| 581 |
print()
|
| 582 |
print("Diagnostic Metrics:")
|
| 583 |
print(f" Retrieval precision: {avg_retrieval_precision:.4f} (|ret∩gt|/|ret|, noise level fed to Stage 3)")
|
|
|
|
| 827 |
"over_selection_ratio": round(r.over_selection_ratio, 2),
|
| 828 |
"why_counts": r.why_counts,
|
| 829 |
"implied_tags": sorted(r.implied_tags),
|
| 830 |
+
# Leaf metrics
|
| 831 |
+
"leaf_precision": round(r.leaf_precision, 4),
|
| 832 |
+
"leaf_recall": round(r.leaf_recall, 4),
|
| 833 |
+
"leaf_f1": round(r.leaf_f1, 4),
|
| 834 |
+
"leaf_selected_count": r.leaf_selected_count,
|
| 835 |
+
"leaf_gt_count": r.leaf_gt_count,
|
| 836 |
# Timing
|
| 837 |
"stage1_time": round(r.stage1_time, 3),
|
| 838 |
"stage2_time": round(r.stage2_time, 3),
|
scripts/preprocess_eval_data.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Preprocess eval dataset: expand ground-truth tags through implication chains.
|
| 2 |
+
|
| 3 |
+
Reads the raw eval JSONL, expands each sample's GT tags via the e621 tag
|
| 4 |
+
implication graph, removes known garbage tags, and writes a new JSONL with
|
| 5 |
+
an additional `tags_ground_truth_expanded` field (flat sorted list).
|
| 6 |
+
|
| 7 |
+
The original `tags_ground_truth_categorized` field is preserved unchanged.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python scripts/preprocess_eval_data.py
|
| 11 |
+
|
| 12 |
+
Input: data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000.jsonl
|
| 13 |
+
Output: data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_expanded.jsonl
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
# Add project root to path so we can import psq_rag
|
| 23 |
+
_REPO_ROOT = Path(__file__).resolve().parent.parent
|
| 24 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 25 |
+
|
| 26 |
+
from psq_rag.retrieval.state import expand_tags_via_implications, get_tag_implications
|
| 27 |
+
|
| 28 |
+
# Tags that are annotation artifacts, not real content tags
GARBAGE_TAGS = frozenset({
    "invalid_tag",
    "invalid_background",
})

# Raw eval sample file (one JSON object per line).
INPUT_PATH = _REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
# Output sits next to the input, with "_expanded" appended before ".jsonl".
OUTPUT_PATH = INPUT_PATH.with_name(INPUT_PATH.stem + "_expanded.jsonl")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def flatten_ground_truth(tags_categorized_str: str) -> set[str]:
    """Parse the categorized ground-truth JSON into a flat set of tags."""
    if not tags_categorized_str:
        return set()
    categories = json.loads(tags_categorized_str)
    # Flatten every list-valued category; non-list values are ignored.
    return {
        tag.strip()
        for values in categories.values()
        if isinstance(values, list)
        for tag in values
    }
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def main() -> int:
    """Expand GT tags for every sample and write the *_expanded.jsonl file.

    Reads INPUT_PATH line by line, removes GARBAGE_TAGS from each sample's
    flattened ground truth, expands the remainder through the implication
    graph, and writes each row (with the added `tags_ground_truth_expanded`
    field) to OUTPUT_PATH.

    Returns:
        Process exit code: 0 on success, 1 if the input file is missing.
    """
    if not INPUT_PATH.is_file():
        print(f"ERROR: Input not found: {INPUT_PATH}")
        return 1

    # Pre-warm implication graph
    impl = get_tag_implications()
    print(f"Loaded {sum(len(v) for v in impl.values())} active implications")

    samples_read = 0
    samples_expanded = 0
    total_tags_added = 0
    total_garbage_removed = 0

    with INPUT_PATH.open("r", encoding="utf-8") as fin, \
            OUTPUT_PATH.open("w", encoding="utf-8") as fout:
        for line in fin:
            row = json.loads(line)
            samples_read += 1

            gt_raw = flatten_ground_truth(row.get("tags_ground_truth_categorized", ""))

            # Remove garbage tags
            garbage_found = gt_raw & GARBAGE_TAGS
            if garbage_found:
                total_garbage_removed += len(garbage_found)
                gt_raw -= garbage_found

            # Expand through implications
            gt_expanded, implied_only = expand_tags_via_implications(gt_raw)
            if implied_only:
                samples_expanded += 1
                total_tags_added += len(implied_only)

            # Store expanded flat list alongside original categorized field
            row["tags_ground_truth_expanded"] = sorted(gt_expanded)

            fout.write(json.dumps(row, ensure_ascii=False) + "\n")

    # Guard: an empty input file would make the averages below divide by zero.
    if samples_read == 0:
        print("WARNING: Input file contained no samples; wrote empty output")
        print(f"Output: {OUTPUT_PATH}")
        return 0

    print(f"Processed {samples_read} samples")
    print(f"  {samples_expanded} samples had missing implications ({samples_expanded}/{samples_read} = {100*samples_expanded/samples_read:.1f}%)")
    print(f"  {total_tags_added} implied tags added total (avg {total_tags_added/samples_read:.1f} per sample)")
    print(f"  {total_garbage_removed} garbage tags removed")
    print(f"Output: {OUTPUT_PATH}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|