# Prompt_Squirrel_RAG / scripts / eval_categorized.py
# Add ranking metrics infrastructure to eval pipeline (commit 0ed7e94)
"""
Evaluate tag predictions with per-category metrics.
Computes:
- Per-category Precision, Recall, F1 (or Accuracy for EXACTLY_ONE categories)
- Per-category Recall@K, Precision@K, MRR for ranked suggestions
- Results organized by category importance (Critical → Important → Nice-to-have)
Usage:
python scripts/eval_categorized.py \
--results data/eval_results/eval_caption_cogvlm_n50_seed42_*.jsonl \
--k 5
This script takes existing eval results (from eval_pipeline.py) and computes
category-specific metrics using the e621 checklist categorization.
"""
from __future__ import annotations
import argparse
import json
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Set, Tuple, Optional
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
sys.path.insert(0, str(_REPO_ROOT))
from psq_rag.tagging.category_parser import parse_checklist, TagCategory
def build_category_tag_index(categories: Dict[str, TagCategory]) -> Dict[str, str]:
    """
    Build a reverse lookup from tag to the category that owns it.

    Each tag is indexed twice: once in its space-separated form and once in
    its raw checklist form (underscores), since predictions may use either.

    Args:
        categories: Category definitions keyed by category name.

    Returns:
        Dict mapping each tag variant to its category name.
    """
    index: Dict[str, str] = {}
    for name, cat in categories.items():
        for raw in cat.tags:
            # Space-separated variant first, then the raw underscore form —
            # same insertion order as indexing both spellings individually.
            for variant in (raw.replace('_', ' '), raw):
                index[variant] = name
    return index
# Category importance levels (for display ordering).
# Lower number = more important. Category names not listed here fall back to
# level 5 (sorted below META) in compute_category_metrics.
CATEGORY_IMPORTANCE = {
    # Critical
    'count': 1,
    'species': 1,
    'body_type': 1,
    # Important
    'gender': 2,
    'clothing': 2,
    'posture': 2,
    'location': 2,
    'perspective': 2,
    # Nice-to-have
    'expression': 3,
    'limbs': 3,
    'gaze': 3,
    'fur_style': 3,
    'hair': 3,
    'body_decor': 3,
    'breasts': 3,
    'general_activity_if_any': 3,
    # Meta
    'quality': 4,
    'style': 4,
    'organization': 4,
    'text': 4,
    'information': 4,
    'requests': 4,
    'resolution': 4,
}

# Section headers printed for each importance tier; levels with no label
# (e.g. the fallback level 5) display as "OTHER".
IMPORTANCE_LABELS = {
    1: "CRITICAL",
    2: "IMPORTANT",
    3: "NICE-TO-HAVE",
    4: "META",
}
@dataclass
class CategoryMetrics:
    """Accumulated evaluation metrics for a single tag category.

    Counter fields are summed across all evaluation samples; the derived
    precision/recall/f1/accuracy properties are computed on demand from the
    running totals (i.e. micro-averaged over the whole run).
    """
    category_name: str  # internal category key (e.g. 'species')
    display_name: str   # human-readable name from the checklist
    importance: int     # tier from CATEGORY_IMPORTANCE (1=critical .. 4=meta)

    # Binary prediction counts.
    tp: int = 0  # predicted and present in ground truth
    fp: int = 0  # predicted but absent from ground truth
    fn: int = 0  # in ground truth but not predicted
    tn: int = 0  # correctly left unpredicted (EXACTLY_ONE categories only)

    # Ranking-metric accumulators for ranked suggestions.
    total_gt_tags: int = 0         # ground-truth tags seen across all samples
    found_in_suggestions: int = 0  # GT tags appearing at any suggestion rank
    recall_at_k: float = 0.0       # summed per-sample recall@K (averaged later)
    precision_at_k: float = 0.0    # summed per-sample precision@K
    mrr: float = 0.0               # summed per-sample mean reciprocal rank
    mrr_count: int = 0             # samples contributing to the MRR sum

    @property
    def precision(self) -> float:
        """TP / (TP + FP); 0.0 when nothing was predicted."""
        denom = self.tp + self.fp
        return self.tp / denom if denom else 0.0

    @property
    def recall(self) -> float:
        """TP / (TP + FN); 0.0 when there was no ground truth."""
        denom = self.tp + self.fn
        return self.tp / denom if denom else 0.0

    @property
    def f1(self) -> float:
        """Harmonic mean of precision and recall; 0.0 when both are zero."""
        p, r = self.precision, self.recall
        return (2 * p * r / (p + r)) if (p + r) else 0.0

    @property
    def accuracy(self) -> float:
        """(TP + TN) over all counts; 0.0 when nothing was recorded."""
        denom = self.tp + self.tn + self.fp + self.fn
        return (self.tp + self.tn) / denom if denom else 0.0
def compute_category_metrics(
    eval_results: List[Dict],
    categories: Dict[str, TagCategory],
    tag_to_category: Dict[str, str],
    k: int = 5,
) -> Dict[str, CategoryMetrics]:
    """
    Compute per-category metrics from eval results.

    Binary metrics (TP/FP/FN, plus TN for EXACTLY_ONE categories) accumulate
    over all samples. Ranking metrics (recall@K, precision@K, MRR) accumulate
    per sample, and only for samples that have both ground-truth tags and
    ranked suggestions for the category.

    Args:
        eval_results: List of evaluation result dicts from eval_pipeline.py
        categories: Category definitions from checklist
        tag_to_category: Mapping from tag to category name
        k: Top-K for ranking metrics

    Returns:
        Dict mapping category_name -> CategoryMetrics
    """
    # One accumulator per category; categories missing from
    # CATEGORY_IMPORTANCE sort below META (level 5).
    metrics: Dict[str, CategoryMetrics] = {}
    for cat_name, category in categories.items():
        metrics[cat_name] = CategoryMetrics(
            category_name=cat_name,
            display_name=category.display_name,
            importance=CATEGORY_IMPORTANCE.get(cat_name, 5),
        )

    for result in eval_results:
        gt_tags = set(result.get('ground_truth_tags', []))
        pred_tags = set(result.get('selected_tags', []))

        # Bucket GT and predicted tags by category; tags absent from the
        # checklist index are ignored entirely.
        gt_by_category = defaultdict(set)
        pred_by_category = defaultdict(set)
        for tag in gt_tags:
            cat = tag_to_category.get(tag)
            if cat:
                gt_by_category[cat].add(tag)
        for tag in pred_tags:
            cat = tag_to_category.get(tag)
            if cat:
                pred_by_category[cat].add(tag)

        for cat_name, category in categories.items():
            cat_metric = metrics[cat_name]
            gt_cat_tags = gt_by_category[cat_name]
            pred_cat_tags = pred_by_category[cat_name]

            # Binary prediction counts for this sample.
            cat_metric.tp += len(gt_cat_tags & pred_cat_tags)  # correct
            cat_metric.fp += len(pred_cat_tags - gt_cat_tags)  # wrong
            cat_metric.fn += len(gt_cat_tags - pred_cat_tags)  # missed

            # EXACTLY_ONE categories have a closed option set, so correct
            # negatives (options neither predicted nor in GT) are meaningful.
            if category.constraint.value == "exactly_one":
                all_options = set(category.tags)
                cat_metric.tn += len(all_options - gt_cat_tags - pred_cat_tags)

            cat_metric.total_gt_tags += len(gt_cat_tags)

            # Ranking metrics (only when categorized suggestions exist and
            # there is ground truth to score against).
            categorized_suggestions = result.get('categorized_suggestions', {})
            cat_suggestions = categorized_suggestions.get(cat_name, [])
            if cat_suggestions and gt_cat_tags:
                # Map tag -> rank (0-indexed; suggestions are pre-sorted by
                # score). BUG FIX: keep the FIRST occurrence of a duplicated
                # tag — the previous dict comprehension overwrote it with the
                # last (worst) rank, deflating MRR.
                suggestion_ranks: Dict[str, int] = {}
                for rank, (tag, _score) in enumerate(cat_suggestions):
                    suggestion_ranks.setdefault(tag, rank)

                # Coverage: GT tags found at any rank.
                cat_metric.found_in_suggestions += sum(
                    1 for gt_tag in gt_cat_tags if gt_tag in suggestion_ranks
                )

                # Recall@K and Precision@K share the same hit count:
                # |GT ∩ top-K|. Precision normalizes by the actual top-K
                # size (< k when fewer suggestions exist).
                top_k_tags = {tag for tag, _score in cat_suggestions[:k]}
                hits = len(gt_cat_tags & top_k_tags)
                cat_metric.recall_at_k += hits / len(gt_cat_tags)
                cat_metric.precision_at_k += hits / min(k, len(cat_suggestions))

                # MRR contribution: mean of 1/(rank+1) over GT tags found
                # anywhere in the suggestion list (+1: ranks are 0-indexed).
                reciprocal_ranks = [
                    1.0 / (suggestion_ranks[gt_tag] + 1)
                    for gt_tag in gt_cat_tags
                    if gt_tag in suggestion_ranks
                ]
                if reciprocal_ranks:
                    cat_metric.mrr += sum(reciprocal_ranks) / len(reciprocal_ranks)
                    cat_metric.mrr_count += 1

    return metrics
def print_category_metrics(
    metrics: Dict[str, CategoryMetrics],
    categories: Dict[str, TagCategory],
    n_samples: int,
    k: int,
):
    """
    Print metrics organized by importance.

    Args:
        metrics: Category metrics
        categories: Category definitions
        n_samples: Number of samples evaluated
        k: Top-K for ranking metrics
    """
    # Group by importance level
    by_importance = defaultdict(list)
    for cat_name, cat_metric in metrics.items():
        by_importance[cat_metric.importance].append(cat_metric)
    # Print in order of importance
    for importance in sorted(by_importance.keys()):
        label = IMPORTANCE_LABELS.get(importance, "OTHER")
        print(f"\n{'='*80}")
        print(f"{label} CATEGORIES")
        print('='*80)
        cat_metrics = by_importance[importance]
        # Alphabetical order within a tier for stable, diffable output.
        cat_metrics.sort(key=lambda m: m.category_name)
        for cat_metric in cat_metrics:
            category = categories[cat_metric.category_name]
            print(f"\n{cat_metric.display_name} ({cat_metric.category_name})")
            print(f" Constraint: {category.constraint.value}")
            print(f" Ground truth tags: {cat_metric.total_gt_tags}")
            # Binary prediction metrics. Accuracy is only shown for
            # EXACTLY_ONE categories (the only ones that accumulate TN);
            # P/R/F1 are printed for every category.
            if category.constraint.value == "exactly_one":
                print(f" Accuracy: {cat_metric.accuracy:.3f}")
            print(f" Precision: {cat_metric.precision:.3f}")
            print(f" Recall: {cat_metric.recall:.3f}")
            print(f" F1: {cat_metric.f1:.3f}")
            # Ranking metrics (averaged across samples)
            if cat_metric.mrr_count > 0:
                # NOTE(review): recall_at_k / precision_at_k sums only
                # include samples that had BOTH suggestions and ground truth
                # for this category (see compute_category_metrics), yet are
                # divided by n_samples here — this understates the averages
                # when coverage is partial. MRR, by contrast, divides by
                # mrr_count. Confirm which denominator is intended.
                avg_recall_at_k = cat_metric.recall_at_k / n_samples if n_samples > 0 else 0
                avg_precision_at_k = cat_metric.precision_at_k / n_samples if n_samples > 0 else 0
                avg_mrr = cat_metric.mrr / cat_metric.mrr_count
                print(f" Recall@{k}: {avg_recall_at_k:.3f} (GT tags found in top-{k})")
                print(f" Precision@{k}: {avg_precision_at_k:.3f} (top-{k} that are correct)")
                print(f" MRR: {avg_mrr:.3f} (mean reciprocal rank)")
                print(f" Coverage: {cat_metric.found_in_suggestions}/{cat_metric.total_gt_tags} (GT tags in suggestions)")
            # Show raw counts for debugging
            print(f" (TP={cat_metric.tp}, FP={cat_metric.fp}, FN={cat_metric.fn}, TN={cat_metric.tn})")
    print(f"\n{'='*80}")
    print("SUMMARY")
    print('='*80)
    # Aggregate by importance level
    for importance in sorted(by_importance.keys()):
        label = IMPORTANCE_LABELS.get(importance, "OTHER")
        cat_metrics = by_importance[importance]
        total_tp = sum(m.tp for m in cat_metrics)
        total_fp = sum(m.fp for m in cat_metrics)
        total_fn = sum(m.fn for m in cat_metrics)
        total_gt = sum(m.total_gt_tags for m in cat_metrics)
        # Micro-averaged metrics (aggregate then calculate)
        micro_p = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
        micro_r = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
        micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r) if (micro_p + micro_r) > 0 else 0
        # Macro-averaged metrics (average across categories). Each list
        # keeps only categories where the metric's denominator is nonzero.
        # NOTE(review): f1s filters on tp+fn (like recalls, not precisions),
        # so a category with predictions but no GT is excluded from
        # macro-F1 — confirm this is intentional.
        precisions = [m.precision for m in cat_metrics if m.tp + m.fp > 0]
        recalls = [m.recall for m in cat_metrics if m.tp + m.fn > 0]
        f1s = [m.f1 for m in cat_metrics if m.tp + m.fn > 0]
        macro_p = sum(precisions) / len(precisions) if precisions else 0
        macro_r = sum(recalls) / len(recalls) if recalls else 0
        macro_f1 = sum(f1s) / len(f1s) if f1s else 0
        print(f"\n{label}:")
        print(f" Total GT tags: {total_gt}")
        print(f" Micro-avg P/R/F1: {micro_p:.3f} / {micro_r:.3f} / {micro_f1:.3f}")
        print(f" Macro-avg P/R/F1: {macro_p:.3f} / {macro_r:.3f} / {macro_f1:.3f}")
    print(f"\n{'='*80}")
def main():
    """CLI entry point: load eval results + checklist, compute and print per-category metrics."""
    parser = argparse.ArgumentParser(
        description="Compute per-category evaluation metrics"
    )
    parser.add_argument(
        "--results",
        required=True,
        help="Path to eval results JSONL file from eval_pipeline.py"
    )
    parser.add_argument(
        "--checklist",
        default=str(_REPO_ROOT / "tagging_checklist.txt"),
        help="Path to e621 tagging checklist"
    )
    parser.add_argument(
        "--k",
        type=int,
        default=5,
        help="Top-K for ranking metrics (default: 5)"
    )
    parser.add_argument(
        "--skip-rating",
        action="store_true",
        default=True,
        help="Skip rating category in evaluation (dataset is rating:safe only)"
    )
    # BUG FIX: --skip-rating was store_true with default=True, so there was
    # no way to turn it off. --include-rating flips the same dest back to
    # False while keeping --skip-rating accepted for backward compatibility.
    parser.add_argument(
        "--include-rating",
        dest="skip_rating",
        action="store_false",
        help="Include the rating category (overrides the default skip)"
    )
    args = parser.parse_args()

    # Load category definitions.
    checklist_path = Path(args.checklist)
    if not checklist_path.exists():
        # Errors go to stderr so piped stdout stays clean.
        print(f"Error: Checklist not found at {checklist_path}", file=sys.stderr)
        sys.exit(1)
    categories = parse_checklist(checklist_path)

    # Remove rating if requested (default).
    if args.skip_rating and 'rating' in categories:
        del categories['rating']

    tag_to_category = build_category_tag_index(categories)

    # Load eval results (JSONL; blank lines and metadata records skipped).
    results_path = Path(args.results)
    if not results_path.exists():
        print(f"Error: Results file not found at {results_path}", file=sys.stderr)
        sys.exit(1)

    eval_results = []
    with open(results_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            # Skip non-dict rows (would crash .get) and '_meta' records.
            if isinstance(record, dict) and not record.get('_meta', False):
                eval_results.append(record)
    print(f"Loaded {len(eval_results)} evaluation results from {results_path}")

    # Compute and print metrics.
    metrics = compute_category_metrics(
        eval_results,
        categories,
        tag_to_category,
        k=args.k,
    )
    print_category_metrics(metrics, categories, len(eval_results), args.k)


if __name__ == "__main__":
    main()