""" Evaluate tag predictions with per-category metrics. Computes: - Per-category Precision, Recall, F1 (or Accuracy for EXACTLY_ONE categories) - Per-category Recall@K, Precision@K, MRR for ranked suggestions - Results organized by category importance (Critical → Important → Nice-to-have) Usage: python scripts/eval_categorized.py \ --results data/eval_results/eval_caption_cogvlm_n50_seed42_*.jsonl \ --k 5 This script takes existing eval results (from eval_pipeline.py) and computes category-specific metrics using the e621 checklist categorization. """ from __future__ import annotations import argparse import json import sys from collections import defaultdict from dataclasses import dataclass from pathlib import Path from typing import Dict, List, Set, Tuple, Optional _REPO_ROOT = Path(__file__).resolve().parents[1] if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) from psq_rag.tagging.category_parser import parse_checklist, TagCategory def build_category_tag_index(categories: Dict[str, TagCategory]) -> Dict[str, str]: """ Build reverse index: tag -> category_name. Args: categories: Category definitions Returns: Dict mapping tag -> category_name """ tag_to_category = {} for cat_name, category in categories.items(): for tag in category.tags: # Normalize tag (the checklist has underscores, tags might have spaces or underscores) normalized = tag.replace('_', ' ') tag_to_category[normalized] = cat_name tag_to_category[tag] = cat_name return tag_to_category # Category importance levels (for display ordering) CATEGORY_IMPORTANCE = { # Critical 'count': 1, 'species': 1, 'body_type': 1, # Important 'gender': 2, 'clothing': 2, 'posture': 2, 'location': 2, 'perspective': 2, # Nice-to-have 'expression': 3, 'limbs': 3, 'gaze': 3, 'fur_style': 3, 'hair': 3, 'body_decor': 3, 'breasts': 3, 'general_activity_if_any': 3, # Meta 'quality': 4, 'style': 4, 'organization': 4, 'text': 4, 'information': 4, 'requests': 4, 'resolution': 4, } IMPORTANCE_LABELS = { 1: "CRITICAL", 2: "IMPORTANT", 3: "NICE-TO-HAVE", 4: "META", } @dataclass class CategoryMetrics: """Metrics for a single category.""" category_name: str display_name: str importance: int # Binary prediction metrics tp: int = 0 # True positives fp: int = 0 # False positives fn: int = 0 # False negatives tn: int = 0 # True negatives # Ranking metrics (for suggestions) total_gt_tags: int = 0 # Total ground truth tags across all samples found_in_suggestions: int = 0 # GT tags that appear anywhere in suggestions recall_at_k: float = 0.0 # Fraction of GT tags found in top-K precision_at_k: float = 0.0 # Fraction of top-K that are correct mrr: float = 0.0 # Mean reciprocal rank mrr_count: int = 0 # Number of GT tags used for MRR calculation @property def precision(self) -> float: """Precision = TP / (TP + FP)""" if self.tp + self.fp == 0: return 0.0 return self.tp / (self.tp + self.fp) @property def recall(self) -> float: """Recall = TP / (TP + FN)""" if self.tp + self.fn == 0: return 0.0 return self.tp / (self.tp + self.fn) @property def f1(self) -> float: """F1 = 2 * (P * R) / (P + R)""" p, r = self.precision, self.recall if p + r == 0: return 0.0 return 2 * p * r / (p + r) @property def accuracy(self) -> float: """Accuracy = (TP + TN) / (TP + TN + FP + FN)""" total = self.tp + self.tn + self.fp + self.fn if total == 0: return 0.0 return (self.tp + self.tn) / total def compute_category_metrics( eval_results: List[Dict], categories: Dict[str, TagCategory], tag_to_category: Dict[str, str], k: int = 5, ) -> Dict[str, CategoryMetrics]: """ Compute per-category metrics from eval results. Args: eval_results: List of evaluation result dicts from eval_pipeline.py categories: Category definitions from checklist tag_to_category: Mapping from tag to category name k: Top-K for ranking metrics Returns: Dict mapping category_name -> CategoryMetrics """ # Initialize metrics for each category metrics: Dict[str, CategoryMetrics] = {} for cat_name, category in categories.items(): importance = CATEGORY_IMPORTANCE.get(cat_name, 5) metrics[cat_name] = CategoryMetrics( category_name=cat_name, display_name=category.display_name, importance=importance, ) # Process each evaluation sample for result in eval_results: # Get ground truth and predicted tags gt_tags = set(result.get('ground_truth_tags', [])) pred_tags = set(result.get('selected_tags', [])) # Organize by category gt_by_category = defaultdict(set) pred_by_category = defaultdict(set) for tag in gt_tags: cat = tag_to_category.get(tag) if cat: gt_by_category[cat].add(tag) for tag in pred_tags: cat = tag_to_category.get(tag) if cat: pred_by_category[cat].add(tag) # Compute metrics per category for this sample for cat_name, category in categories.items(): cat_metric = metrics[cat_name] gt_cat_tags = gt_by_category[cat_name] pred_cat_tags = pred_by_category[cat_name] # Binary prediction metrics tp = len(gt_cat_tags & pred_cat_tags) # Correct predictions fp = len(pred_cat_tags - gt_cat_tags) # Wrong predictions fn = len(gt_cat_tags - pred_cat_tags) # Missed tags cat_metric.tp += tp cat_metric.fp += fp cat_metric.fn += fn # For EXACTLY_ONE categories, also track TN (correct negatives) if category.constraint.value == "exactly_one": # All other options in this category that weren't predicted or in GT all_options = set(category.tags) tn_tags = all_options - gt_cat_tags - pred_cat_tags cat_metric.tn += len(tn_tags) cat_metric.total_gt_tags += len(gt_cat_tags) # Ranking metrics (if categorized_suggestions are available) categorized_suggestions = result.get('categorized_suggestions', {}) cat_suggestions = categorized_suggestions.get(cat_name, []) if cat_suggestions and gt_cat_tags: # Convert to dict for easier lookup: {tag: rank} # Suggestions are already sorted by score, so index = rank (0-indexed) suggestion_ranks = {tag: rank for rank, (tag, score) in enumerate(cat_suggestions)} # Count how many GT tags appear in suggestions (at any rank) found_count = sum(1 for gt_tag in gt_cat_tags if gt_tag in suggestion_ranks) cat_metric.found_in_suggestions += found_count # Recall@K: fraction of GT tags in top-K top_k_tags = {tag for tag, score in cat_suggestions[:k]} recall_at_k_count = len(gt_cat_tags & top_k_tags) # Precision@K: fraction of top-K that are in GT if len(top_k_tags) > 0: precision_at_k_count = len(top_k_tags & gt_cat_tags) else: precision_at_k_count = 0 # MRR: mean of 1/rank for each GT tag found in suggestions reciprocal_ranks = [] for gt_tag in gt_cat_tags: if gt_tag in suggestion_ranks: rank = suggestion_ranks[gt_tag] reciprocal_ranks.append(1.0 / (rank + 1)) # +1 because rank is 0-indexed # Accumulate for averaging later cat_metric.recall_at_k += recall_at_k_count / len(gt_cat_tags) if gt_cat_tags else 0 cat_metric.precision_at_k += precision_at_k_count / min(k, len(cat_suggestions)) if cat_suggestions else 0 if reciprocal_ranks: cat_metric.mrr += sum(reciprocal_ranks) / len(reciprocal_ranks) cat_metric.mrr_count += 1 return metrics def print_category_metrics( metrics: Dict[str, CategoryMetrics], categories: Dict[str, TagCategory], n_samples: int, k: int, ): """ Print metrics organized by importance. Args: metrics: Category metrics categories: Category definitions n_samples: Number of samples evaluated k: Top-K for ranking metrics """ # Group by importance level by_importance = defaultdict(list) for cat_name, cat_metric in metrics.items(): by_importance[cat_metric.importance].append(cat_metric) # Print in order of importance for importance in sorted(by_importance.keys()): label = IMPORTANCE_LABELS.get(importance, "OTHER") print(f"\n{'='*80}") print(f"{label} CATEGORIES") print('='*80) cat_metrics = by_importance[importance] cat_metrics.sort(key=lambda m: m.category_name) for cat_metric in cat_metrics: category = categories[cat_metric.category_name] print(f"\n{cat_metric.display_name} ({cat_metric.category_name})") print(f" Constraint: {category.constraint.value}") print(f" Ground truth tags: {cat_metric.total_gt_tags}") # Binary prediction metrics if category.constraint.value == "exactly_one": print(f" Accuracy: {cat_metric.accuracy:.3f}") print(f" Precision: {cat_metric.precision:.3f}") print(f" Recall: {cat_metric.recall:.3f}") print(f" F1: {cat_metric.f1:.3f}") # Ranking metrics (averaged across samples) if cat_metric.mrr_count > 0: avg_recall_at_k = cat_metric.recall_at_k / n_samples if n_samples > 0 else 0 avg_precision_at_k = cat_metric.precision_at_k / n_samples if n_samples > 0 else 0 avg_mrr = cat_metric.mrr / cat_metric.mrr_count print(f" Recall@{k}: {avg_recall_at_k:.3f} (GT tags found in top-{k})") print(f" Precision@{k}: {avg_precision_at_k:.3f} (top-{k} that are correct)") print(f" MRR: {avg_mrr:.3f} (mean reciprocal rank)") print(f" Coverage: {cat_metric.found_in_suggestions}/{cat_metric.total_gt_tags} (GT tags in suggestions)") # Show raw counts for debugging print(f" (TP={cat_metric.tp}, FP={cat_metric.fp}, FN={cat_metric.fn}, TN={cat_metric.tn})") print(f"\n{'='*80}") print("SUMMARY") print('='*80) # Aggregate by importance level for importance in sorted(by_importance.keys()): label = IMPORTANCE_LABELS.get(importance, "OTHER") cat_metrics = by_importance[importance] total_tp = sum(m.tp for m in cat_metrics) total_fp = sum(m.fp for m in cat_metrics) total_fn = sum(m.fn for m in cat_metrics) total_gt = sum(m.total_gt_tags for m in cat_metrics) # Micro-averaged metrics (aggregate then calculate) micro_p = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0 micro_r = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0 micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r) if (micro_p + micro_r) > 0 else 0 # Macro-averaged metrics (average across categories) precisions = [m.precision for m in cat_metrics if m.tp + m.fp > 0] recalls = [m.recall for m in cat_metrics if m.tp + m.fn > 0] f1s = [m.f1 for m in cat_metrics if m.tp + m.fn > 0] macro_p = sum(precisions) / len(precisions) if precisions else 0 macro_r = sum(recalls) / len(recalls) if recalls else 0 macro_f1 = sum(f1s) / len(f1s) if f1s else 0 print(f"\n{label}:") print(f" Total GT tags: {total_gt}") print(f" Micro-avg P/R/F1: {micro_p:.3f} / {micro_r:.3f} / {micro_f1:.3f}") print(f" Macro-avg P/R/F1: {macro_p:.3f} / {macro_r:.3f} / {macro_f1:.3f}") print(f"\n{'='*80}") def main(): parser = argparse.ArgumentParser( description="Compute per-category evaluation metrics" ) parser.add_argument( "--results", required=True, help="Path to eval results JSONL file from eval_pipeline.py" ) parser.add_argument( "--checklist", default=str(_REPO_ROOT / "tagging_checklist.txt"), help="Path to e621 tagging checklist" ) parser.add_argument( "--k", type=int, default=5, help="Top-K for ranking metrics (default: 5)" ) parser.add_argument( "--skip-rating", action="store_true", default=True, help="Skip rating category in evaluation (dataset is rating:safe only)" ) args = parser.parse_args() # Load category definitions checklist_path = Path(args.checklist) if not checklist_path.exists(): print(f"Error: Checklist not found at {checklist_path}") sys.exit(1) categories = parse_checklist(checklist_path) # Remove rating if requested if args.skip_rating and 'rating' in categories: del categories['rating'] tag_to_category = build_category_tag_index(categories) # Load eval results results_path = Path(args.results) if not results_path.exists(): print(f"Error: Results file not found at {results_path}") sys.exit(1) eval_results = [] with open(results_path, 'r') as f: for line in f: if line.strip(): result = json.loads(line) # Skip metadata lines if not result.get('_meta', False): eval_results.append(result) print(f"Loaded {len(eval_results)} evaluation results from {results_path}") # Compute metrics metrics = compute_category_metrics( eval_results, categories, tag_to_category, k=args.k, ) # Print results print_category_metrics(metrics, categories, len(eval_results), args.k) if __name__ == "__main__": main()