Spaces:
Running
Running
| """ | |
| Evaluate tag predictions with per-category metrics. | |
| Computes: | |
| - Per-category Precision, Recall, F1 (or Accuracy for EXACTLY_ONE categories) | |
| - Per-category Recall@K, Precision@K, MRR for ranked suggestions | |
| - Results organized by category importance (Critical → Important → Nice-to-have) | |
| Usage: | |
| python scripts/eval_categorized.py \ | |
| --results data/eval_results/eval_caption_cogvlm_n50_seed42_*.jsonl \ | |
| --k 5 | |
| This script takes existing eval results (from eval_pipeline.py) and computes | |
| category-specific metrics using the e621 checklist categorization. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from collections import defaultdict | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, List, Set, Tuple, Optional | |
# Make the repository root importable when this file is run directly as a
# script (e.g. `python scripts/eval_categorized.py`), since scripts/ itself
# is not a package root.
_REPO_ROOT = Path(__file__).resolve().parents[1]
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
from psq_rag.tagging.category_parser import parse_checklist, TagCategory
def build_category_tag_index(categories: Dict[str, TagCategory]) -> Dict[str, str]:
    """
    Build reverse index: tag -> category_name.

    Each checklist tag is indexed under two spellings: the raw checklist form
    (underscores) and a space-normalized form, because predicted tags may use
    either convention.

    Args:
        categories: Category definitions

    Returns:
        Dict mapping tag -> category_name
    """
    # For every tag emit both the space-normalized and raw variants; on a
    # (theoretical) collision the last category wins, matching the original
    # insertion order.
    return {
        variant: cat_name
        for cat_name, category in categories.items()
        for raw_tag in category.tags
        for variant in (raw_tag.replace('_', ' '), raw_tag)
    }
# Category importance levels (for display ordering).
# Lower number = more important; categories absent from this map are treated
# as level 5 ("other") by compute_category_metrics().
CATEGORY_IMPORTANCE = {
    # Critical
    'count': 1,
    'species': 1,
    'body_type': 1,
    # Important
    'gender': 2,
    'clothing': 2,
    'posture': 2,
    'location': 2,
    'perspective': 2,
    # Nice-to-have
    'expression': 3,
    'limbs': 3,
    'gaze': 3,
    'fur_style': 3,
    'hair': 3,
    'body_decor': 3,
    'breasts': 3,
    'general_activity_if_any': 3,
    # Meta
    'quality': 4,
    'style': 4,
    'organization': 4,
    'text': 4,
    'information': 4,
    'requests': 4,
    'resolution': 4,
}

# Section headings printed for each importance level.
IMPORTANCE_LABELS = {
    1: "CRITICAL",
    2: "IMPORTANT",
    3: "NICE-TO-HAVE",
    4: "META",
}
@dataclass
class CategoryMetrics:
    """Accumulated evaluation metrics for a single tag category.

    Raw counters are accumulated across samples by compute_category_metrics();
    precision / recall / f1 / accuracy are derived, read-only properties.
    """
    category_name: str  # internal category key (e.g. 'species')
    display_name: str   # human-readable name from the checklist
    importance: int     # 1=critical .. 4=meta (see CATEGORY_IMPORTANCE)
    # Binary prediction counters
    tp: int = 0  # True positives
    fp: int = 0  # False positives
    fn: int = 0  # False negatives
    tn: int = 0  # True negatives (only tracked for EXACTLY_ONE categories)
    # Ranking metric accumulators (for suggestions)
    total_gt_tags: int = 0         # Total ground truth tags across all samples
    found_in_suggestions: int = 0  # GT tags that appear anywhere in suggestions
    recall_at_k: float = 0.0       # Sum of per-sample recall@K (averaged by caller)
    precision_at_k: float = 0.0    # Sum of per-sample precision@K (averaged by caller)
    mrr: float = 0.0               # Sum of per-sample mean reciprocal rank
    mrr_count: int = 0             # Number of samples contributing to MRR

    # BUG FIX: the class was missing its @dataclass decorator, so the
    # keyword-argument construction CategoryMetrics(category_name=..., ...)
    # raised TypeError (the `dataclass` import was already present but unused).
    # The derived metrics are now @property because every caller reads them as
    # attributes (e.g. `p, r = self.precision, self.recall` and
    # f"{m.accuracy:.3f}"), which on plain methods would operate on
    # bound-method objects instead of floats.

    @property
    def precision(self) -> float:
        """Precision = TP / (TP + FP); 0.0 when undefined."""
        if self.tp + self.fp == 0:
            return 0.0
        return self.tp / (self.tp + self.fp)

    @property
    def recall(self) -> float:
        """Recall = TP / (TP + FN); 0.0 when undefined."""
        if self.tp + self.fn == 0:
            return 0.0
        return self.tp / (self.tp + self.fn)

    @property
    def f1(self) -> float:
        """F1 = 2 * (P * R) / (P + R); 0.0 when undefined."""
        p, r = self.precision, self.recall
        if p + r == 0:
            return 0.0
        return 2 * p * r / (p + r)

    @property
    def accuracy(self) -> float:
        """Accuracy = (TP + TN) / (TP + TN + FP + FN); 0.0 when undefined."""
        total = self.tp + self.tn + self.fp + self.fn
        if total == 0:
            return 0.0
        return (self.tp + self.tn) / total
def compute_category_metrics(
    eval_results: List[Dict],
    categories: Dict[str, TagCategory],
    tag_to_category: Dict[str, str],
    k: int = 5,
) -> Dict[str, CategoryMetrics]:
    """
    Compute per-category metrics from eval results.

    Accumulates TP/FP/FN (and TN for EXACTLY_ONE categories) plus per-sample
    ranking metrics (recall@K, precision@K, MRR) into one CategoryMetrics per
    category. Tags that are not in `tag_to_category` are silently ignored.

    Args:
        eval_results: List of evaluation result dicts from eval_pipeline.py
        categories: Category definitions from checklist
        tag_to_category: Mapping from tag to category name
        k: Top-K for ranking metrics

    Returns:
        Dict mapping category_name -> CategoryMetrics
    """
    # Initialize metrics for each category
    metrics: Dict[str, CategoryMetrics] = {}
    for cat_name, category in categories.items():
        # Categories missing from CATEGORY_IMPORTANCE sort after META (level 5).
        importance = CATEGORY_IMPORTANCE.get(cat_name, 5)
        metrics[cat_name] = CategoryMetrics(
            category_name=cat_name,
            display_name=category.display_name,
            importance=importance,
        )
    # Process each evaluation sample
    for result in eval_results:
        # Get ground truth and predicted tags
        gt_tags = set(result.get('ground_truth_tags', []))
        pred_tags = set(result.get('selected_tags', []))
        # Organize by category (unknown tags are dropped here)
        gt_by_category = defaultdict(set)
        pred_by_category = defaultdict(set)
        for tag in gt_tags:
            cat = tag_to_category.get(tag)
            if cat:
                gt_by_category[cat].add(tag)
        for tag in pred_tags:
            cat = tag_to_category.get(tag)
            if cat:
                pred_by_category[cat].add(tag)
        # Compute metrics per category for this sample
        for cat_name, category in categories.items():
            cat_metric = metrics[cat_name]
            gt_cat_tags = gt_by_category[cat_name]
            pred_cat_tags = pred_by_category[cat_name]
            # Binary prediction metrics via set algebra
            tp = len(gt_cat_tags & pred_cat_tags)  # Correct predictions
            fp = len(pred_cat_tags - gt_cat_tags)  # Wrong predictions
            fn = len(gt_cat_tags - pred_cat_tags)  # Missed tags
            cat_metric.tp += tp
            cat_metric.fp += fp
            cat_metric.fn += fn
            # For EXACTLY_ONE categories, also track TN (correct negatives)
            if category.constraint.value == "exactly_one":
                # All other options in this category that weren't predicted or in GT
                # NOTE(review): category.tags holds the raw (underscored) checklist
                # forms, while GT/pred tags may be space-normalized, so this
                # difference could slightly overcount TNs — verify normalization.
                all_options = set(category.tags)
                tn_tags = all_options - gt_cat_tags - pred_cat_tags
                cat_metric.tn += len(tn_tags)
            cat_metric.total_gt_tags += len(gt_cat_tags)
            # Ranking metrics (if categorized_suggestions are available)
            categorized_suggestions = result.get('categorized_suggestions', {})
            cat_suggestions = categorized_suggestions.get(cat_name, [])
            if cat_suggestions and gt_cat_tags:
                # Convert to dict for easier lookup: {tag: rank}
                # Suggestions are already sorted by score, so index = rank (0-indexed)
                suggestion_ranks = {tag: rank for rank, (tag, score) in enumerate(cat_suggestions)}
                # Count how many GT tags appear in suggestions (at any rank)
                found_count = sum(1 for gt_tag in gt_cat_tags if gt_tag in suggestion_ranks)
                cat_metric.found_in_suggestions += found_count
                # Recall@K: fraction of GT tags in top-K
                top_k_tags = {tag for tag, score in cat_suggestions[:k]}
                recall_at_k_count = len(gt_cat_tags & top_k_tags)
                # Precision@K: fraction of top-K that are in GT
                if len(top_k_tags) > 0:
                    precision_at_k_count = len(top_k_tags & gt_cat_tags)
                else:
                    precision_at_k_count = 0
                # MRR: mean of 1/rank for each GT tag found in suggestions
                reciprocal_ranks = []
                for gt_tag in gt_cat_tags:
                    if gt_tag in suggestion_ranks:
                        rank = suggestion_ranks[gt_tag]
                        reciprocal_ranks.append(1.0 / (rank + 1))  # +1 because rank is 0-indexed
                # Accumulate per-sample averages for later division.
                # NOTE(review): the printer divides recall@K / precision@K by
                # n_samples (all samples) but MRR by mrr_count (only samples
                # that contributed) — confirm this asymmetry is intended.
                # (The trailing `if ... else 0` guards are redundant inside this
                # branch, since both conditions are already known truthy.)
                cat_metric.recall_at_k += recall_at_k_count / len(gt_cat_tags) if gt_cat_tags else 0
                cat_metric.precision_at_k += precision_at_k_count / min(k, len(cat_suggestions)) if cat_suggestions else 0
                if reciprocal_ranks:
                    cat_metric.mrr += sum(reciprocal_ranks) / len(reciprocal_ranks)
                    cat_metric.mrr_count += 1
    return metrics
def print_category_metrics(
    metrics: Dict[str, CategoryMetrics],
    categories: Dict[str, TagCategory],
    n_samples: int,
    k: int,
):
    """
    Print metrics organized by importance.

    Emits one section per importance level (CRITICAL first), then a SUMMARY
    section with micro- and macro-averaged P/R/F1 per level.

    Args:
        metrics: Category metrics
        categories: Category definitions
        n_samples: Number of samples evaluated
        k: Top-K for ranking metrics
    """
    # Group by importance level
    by_importance = defaultdict(list)
    for cat_name, cat_metric in metrics.items():
        by_importance[cat_metric.importance].append(cat_metric)
    # Print in order of importance (1=critical .. 4=meta, 5=uncategorized)
    for importance in sorted(by_importance.keys()):
        label = IMPORTANCE_LABELS.get(importance, "OTHER")
        print(f"\n{'='*80}")
        print(f"{label} CATEGORIES")
        print('='*80)
        cat_metrics = by_importance[importance]
        # Deterministic ordering within a level
        cat_metrics.sort(key=lambda m: m.category_name)
        for cat_metric in cat_metrics:
            category = categories[cat_metric.category_name]
            print(f"\n{cat_metric.display_name} ({cat_metric.category_name})")
            print(f"  Constraint: {category.constraint.value}")
            print(f"  Ground truth tags: {cat_metric.total_gt_tags}")
            # Binary prediction metrics.
            # NOTE(review): precision/recall/f1/accuracy are formatted via plain
            # attribute access, which requires them to be properties on
            # CategoryMetrics — verify they are not plain methods.
            if category.constraint.value == "exactly_one":
                # Accuracy only makes sense for single-choice categories (TN is tracked)
                print(f"  Accuracy: {cat_metric.accuracy:.3f}")
            print(f"  Precision: {cat_metric.precision:.3f}")
            print(f"  Recall: {cat_metric.recall:.3f}")
            print(f"  F1: {cat_metric.f1:.3f}")
            # Ranking metrics (averaged across samples)
            if cat_metric.mrr_count > 0:
                # recall@K / precision@K are averaged over ALL samples, while MRR
                # is averaged over only the samples that contributed (mrr_count).
                avg_recall_at_k = cat_metric.recall_at_k / n_samples if n_samples > 0 else 0
                avg_precision_at_k = cat_metric.precision_at_k / n_samples if n_samples > 0 else 0
                avg_mrr = cat_metric.mrr / cat_metric.mrr_count
                print(f"  Recall@{k}: {avg_recall_at_k:.3f} (GT tags found in top-{k})")
                print(f"  Precision@{k}: {avg_precision_at_k:.3f} (top-{k} that are correct)")
                print(f"  MRR: {avg_mrr:.3f} (mean reciprocal rank)")
                print(f"  Coverage: {cat_metric.found_in_suggestions}/{cat_metric.total_gt_tags} (GT tags in suggestions)")
            # Show raw counts for debugging
            print(f"  (TP={cat_metric.tp}, FP={cat_metric.fp}, FN={cat_metric.fn}, TN={cat_metric.tn})")
    print(f"\n{'='*80}")
    print("SUMMARY")
    print('='*80)
    # Aggregate by importance level
    for importance in sorted(by_importance.keys()):
        label = IMPORTANCE_LABELS.get(importance, "OTHER")
        cat_metrics = by_importance[importance]
        total_tp = sum(m.tp for m in cat_metrics)
        total_fp = sum(m.fp for m in cat_metrics)
        total_fn = sum(m.fn for m in cat_metrics)
        total_gt = sum(m.total_gt_tags for m in cat_metrics)
        # Micro-averaged metrics (aggregate counts first, then compute)
        micro_p = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
        micro_r = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
        micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r) if (micro_p + micro_r) > 0 else 0
        # Macro-averaged metrics (average per-category values); categories with
        # no predictions / no GT are excluded from the respective average.
        precisions = [m.precision for m in cat_metrics if m.tp + m.fp > 0]
        recalls = [m.recall for m in cat_metrics if m.tp + m.fn > 0]
        f1s = [m.f1 for m in cat_metrics if m.tp + m.fn > 0]
        macro_p = sum(precisions) / len(precisions) if precisions else 0
        macro_r = sum(recalls) / len(recalls) if recalls else 0
        macro_f1 = sum(f1s) / len(f1s) if f1s else 0
        print(f"\n{label}:")
        print(f"  Total GT tags: {total_gt}")
        print(f"  Micro-avg P/R/F1: {micro_p:.3f} / {micro_r:.3f} / {micro_f1:.3f}")
        print(f"  Macro-avg P/R/F1: {macro_p:.3f} / {macro_r:.3f} / {macro_f1:.3f}")
    print(f"\n{'='*80}")
def main():
    """CLI entry point: load eval results and checklist, then print per-category metrics.

    Exits with status 1 if the checklist or results file is missing.
    """
    parser = argparse.ArgumentParser(
        description="Compute per-category evaluation metrics"
    )
    parser.add_argument(
        "--results",
        required=True,
        help="Path to eval results JSONL file from eval_pipeline.py"
    )
    parser.add_argument(
        "--checklist",
        default=str(_REPO_ROOT / "tagging_checklist.txt"),
        help="Path to e621 tagging checklist"
    )
    parser.add_argument(
        "--k",
        type=int,
        default=5,
        help="Top-K for ranking metrics (default: 5)"
    )
    # BUG FIX: this flag was store_true with default=True, so it could never be
    # turned off. Rating is still skipped by default (backward compatible);
    # --include-rating now makes the toggle actually usable.
    parser.add_argument(
        "--skip-rating",
        dest="skip_rating",
        action="store_true",
        default=True,
        help="Skip rating category in evaluation (dataset is rating:safe only; default)"
    )
    parser.add_argument(
        "--include-rating",
        dest="skip_rating",
        action="store_false",
        help="Evaluate the rating category as well (overrides --skip-rating)"
    )
    args = parser.parse_args()

    # Load category definitions
    checklist_path = Path(args.checklist)
    if not checklist_path.exists():
        print(f"Error: Checklist not found at {checklist_path}")
        sys.exit(1)
    categories = parse_checklist(checklist_path)

    # Remove rating if requested
    if args.skip_rating and 'rating' in categories:
        del categories['rating']

    tag_to_category = build_category_tag_index(categories)

    # Load eval results
    results_path = Path(args.results)
    if not results_path.exists():
        print(f"Error: Results file not found at {results_path}")
        sys.exit(1)

    eval_results = []
    with open(results_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                result = json.loads(line)
                # Skip metadata lines (records carrying a truthy '_meta' key)
                if not result.get('_meta', False):
                    eval_results.append(result)
    print(f"Loaded {len(eval_results)} evaluation results from {results_path}")

    # Compute metrics
    metrics = compute_category_metrics(
        eval_results,
        categories,
        tag_to_category,
        k=args.k,
    )

    # Print results
    print_category_metrics(metrics, categories, len(eval_results), args.k)