"""Centroid-based category suggestions using reduced TF-IDF tag vectors.

This script uses e621 checklist-documented categories as seed centroids, then
scores uncategorized tags against those centroids. Besides the checklist, seed
tags are merged in from provisional registry categories (rows whose
category_status is "proposed" or "proposed_missing"), from an optional manual
override CSV, and from e621 wiki tag groups mapped to categories by a CSV map.
Tags filed under "uncategorized_review" in the registry are then labeled
"auto_accept", "needs_review", or "hold" from their cosine similarity to the
best centroid and their margin over the runner-up centroid.

Outputs:
- data/analysis/category_centroid_review.csv
- data/analysis/category_centroid_summary.json

Optional inputs (example template CSVs are written on first run if missing):
- data/analysis/category_seed_overrides.csv
- data/analysis/category_tag_group_map.csv (read together with data/tag_groups.json)
"""

from __future__ import annotations

import csv
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple

import numpy as np

from psq_rag.retrieval.state import get_tag_counts, get_tfidf_tag_vectors
from psq_rag.tagging.category_parser import parse_checklist

# All paths are resolved against the repository root, i.e. the parent of the
# directory that contains this script, so the script works from any CWD.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Required inputs (main() raises FileNotFoundError if either is missing).
_REGISTRY_PATH = _REPO_ROOT / "data" / "category_registry.csv"
_CHECKLIST_PATH = _REPO_ROOT / "tagging_checklist.txt"
# Optional seed inputs; the two CSVs get example templates if absent.
_SEED_OVERRIDES_PATH = _REPO_ROOT / "data" / "analysis" / "category_seed_overrides.csv"
_TAG_GROUPS_PATH = _REPO_ROOT / "data" / "tag_groups.json"
_TAG_GROUP_MAP_PATH = _REPO_ROOT / "data" / "analysis" / "category_tag_group_map.csv"
# Generated outputs.
_OUT_REVIEW_PATH = _REPO_ROOT / "data" / "analysis" / "category_centroid_review.csv"
_OUT_SUMMARY_PATH = _REPO_ROOT / "data" / "analysis" / "category_centroid_summary.json"

# Conservative defaults: only auto-accept when assignment is clear.
AUTO_SIM_MIN = 0.78 AUTO_MARGIN_MIN = 0.06 REVIEW_SIM_MIN = 0.65 REVIEW_MARGIN_MIN = 0.03 def _load_registry_rows(path: Path) -> List[Dict[str, str]]: with path.open("r", encoding="utf-8", newline="") as f: return list(csv.DictReader(f)) def _load_seed_overrides(path: Path) -> Dict[str, Set[str]]: if not path.is_file(): return {} overrides: Dict[str, Set[str]] = defaultdict(set) with path.open("r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) for row in reader: if row.get("enabled", "1").strip() not in {"1", "true", "True"}: continue category = (row.get("category_name") or "").strip() tag = (row.get("tag") or "").strip() if category and tag: overrides[category].add(tag) return overrides def _write_seed_override_template(path: Path) -> None: if path.exists(): return path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", encoding="utf-8", newline="") as f: writer = csv.writer(f) writer.writerow(["category_name", "tag", "enabled", "seed_note"]) writer.writerow(["objects_props", "bed", "1", "example manual seed"]) writer.writerow(["background_composition", "indoors", "1", "example manual seed"]) writer.writerow(["pose_action_detail", "stretching", "1", "example manual seed"]) def _write_tag_group_map_template(path: Path) -> None: if path.exists(): return path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", encoding="utf-8", newline="") as f: writer = csv.writer(f) writer.writerow(["category_name", "tag_group", "enabled", "seed_note"]) writer.writerow(["clothing_detail", "clothes", "1", "e621 wiki tag group"]) writer.writerow(["expression_detail", "facial_expressions", "1", "e621 wiki tag group"]) writer.writerow(["objects_props", "food", "1", "e621 wiki tag group"]) writer.writerow(["pose_action_detail", "pose", "1", "e621 wiki tag group"]) def _seed_categories_from_checklist() -> Dict[str, Set[str]]: categories = parse_checklist(_CHECKLIST_PATH) return {name: set(cat.tags) for name, cat in categories.items()} def 
_seed_proposed_categories_from_registry(rows: List[Dict[str, str]], top_n: int = 12) -> Dict[str, Set[str]]: checklist_categories = set(_seed_categories_from_checklist().keys()) grouped: Dict[str, List[Tuple[str, int]]] = defaultdict(list) for row in rows: category = (row.get("category_name") or "").strip() tag = (row.get("tag") or "").strip() status = (row.get("category_status") or "").strip() if not tag or not category: continue if category in {"uncategorized_review", "nsfw_excluded"}: continue if category in checklist_categories: continue if status not in {"proposed_missing", "proposed"}: continue try: freq = int(row.get("tag_fluffyrock_count") or "0") except ValueError: freq = 0 grouped[category].append((tag, freq)) out: Dict[str, Set[str]] = {} for category, entries in grouped.items(): entries.sort(key=lambda x: x[1], reverse=True) out[category] = {tag for tag, _ in entries[:top_n]} return out def _seed_from_tag_groups(tag_groups_path: Path, map_path: Path) -> Tuple[Dict[str, Set[str]], int, Set[str]]: if not tag_groups_path.is_file() or not map_path.is_file(): return {}, 0, set() with tag_groups_path.open("r", encoding="utf-8") as f: tag_groups = json.load(f) added = 0 out: Dict[str, Set[str]] = defaultdict(set) ignored_wiki_groups: Set[str] = set() with map_path.open("r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) for row in reader: if row.get("enabled", "1").strip() not in {"1", "true", "True"}: continue category = (row.get("category_name") or "").strip() group = (row.get("tag_group") or "").strip() if not category or not group: continue if category.lower().startswith("ignored_"): ignored_wiki_groups.add(group) continue members = tag_groups.get(group, []) if not isinstance(members, list): continue for tag in members: if isinstance(tag, str) and tag: out[category].add(tag) added += 1 return out, added, ignored_wiki_groups def _build_centroids( seed_sets: Dict[str, Set[str]], tag_to_row: Dict[str, int], vectors_norm: np.ndarray, ) -> 
Tuple[Dict[str, np.ndarray], Dict[str, int]]: centroids: Dict[str, np.ndarray] = {} seed_sizes: Dict[str, int] = {} for category, seeds in seed_sets.items(): idxs = [tag_to_row[tag] for tag in seeds if tag in tag_to_row] if len(idxs) < 2: continue mat = vectors_norm[idxs] centroid = mat.mean(axis=0) norm = np.linalg.norm(centroid) if norm == 0: continue centroids[category] = centroid / norm seed_sizes[category] = len(idxs) return centroids, seed_sizes def _candidate_tags(rows: List[Dict[str, str]]) -> List[Tuple[str, int]]: seen: Set[str] = set() candidates: List[Tuple[str, int]] = [] for row in rows: category = (row.get("category_name") or "").strip() tag = (row.get("tag") or "").strip() if category != "uncategorized_review" or not tag or tag in seen: continue seen.add(tag) try: freq = int(row.get("tag_fluffyrock_count") or "0") except ValueError: freq = 0 candidates.append((tag, freq)) return candidates def _decision(top_sim: float, margin: float) -> str: if top_sim >= AUTO_SIM_MIN and margin >= AUTO_MARGIN_MIN: return "auto_accept" if top_sim >= REVIEW_SIM_MIN and margin >= REVIEW_MARGIN_MIN: return "needs_review" return "hold" def main() -> None: if not _REGISTRY_PATH.is_file(): raise FileNotFoundError(f"Missing registry file: {_REGISTRY_PATH}") if not _CHECKLIST_PATH.is_file(): raise FileNotFoundError(f"Missing checklist file: {_CHECKLIST_PATH}") _write_seed_override_template(_SEED_OVERRIDES_PATH) _write_tag_group_map_template(_TAG_GROUP_MAP_PATH) rows = _load_registry_rows(_REGISTRY_PATH) seed_sets = _seed_categories_from_checklist() provisional = _seed_proposed_categories_from_registry(rows) for category, tags in provisional.items(): seed_sets.setdefault(category, set()).update(tags) overrides = _load_seed_overrides(_SEED_OVERRIDES_PATH) for category, tags in overrides.items(): seed_sets.setdefault(category, set()).update(tags) tag_group_seeds, n_tag_group_seeds, ignored_wiki_groups = _seed_from_tag_groups(_TAG_GROUPS_PATH, _TAG_GROUP_MAP_PATH) for category, 
tags in tag_group_seeds.items(): seed_sets.setdefault(category, set()).update(tags) vectors = get_tfidf_tag_vectors() vectors_norm = vectors["reduced_matrix_norm"] tag_to_row = vectors["tag_to_row_index"] centroids, seed_sizes = _build_centroids(seed_sets, tag_to_row, vectors_norm) if not centroids: raise RuntimeError("No category centroids created. Check seeds and vector availability.") category_names = sorted(centroids.keys()) centroid_matrix = np.stack([centroids[name] for name in category_names], axis=0) counts = get_tag_counts() candidates = _candidate_tags(rows) review_rows: List[Dict[str, str]] = [] bucket_counts = defaultdict(int) for tag, fallback_freq in candidates: idx = tag_to_row.get(tag) if idx is None: continue sims = centroid_matrix @ vectors_norm[idx] if sims.size == 0: continue order = np.argsort(sims)[::-1] top_i = int(order[0]) top2_i = int(order[1]) if sims.size > 1 else top_i top_sim = float(sims[top_i]) second_sim = float(sims[top2_i]) margin = top_sim - second_sim decision = _decision(top_sim, margin) bucket_counts[decision] += 1 review_rows.append( { "tag": tag, "fluffyrock_count": str(counts.get(tag, fallback_freq)), "best_category": category_names[top_i], "best_sim": f"{top_sim:.6f}", "second_category": category_names[top2_i], "second_sim": f"{second_sim:.6f}", "margin": f"{margin:.6f}", "decision": decision, } ) review_rows.sort( key=lambda r: ( {"auto_accept": 0, "needs_review": 1, "hold": 2}[r["decision"]], -int(r["fluffyrock_count"]), -float(r["best_sim"]), ) ) _OUT_REVIEW_PATH.parent.mkdir(parents=True, exist_ok=True) with _OUT_REVIEW_PATH.open("w", encoding="utf-8", newline="") as f: writer = csv.DictWriter( f, fieldnames=[ "tag", "fluffyrock_count", "best_category", "best_sim", "second_category", "second_sim", "margin", "decision", ], ) writer.writeheader() writer.writerows(review_rows) centroid_overlap = [] for i, c1 in enumerate(category_names): for j in range(i + 1, len(category_names)): c2 = category_names[j] sim = 
float(np.dot(centroids[c1], centroids[c2])) if sim >= 0.70: centroid_overlap.append({"category_a": c1, "category_b": c2, "centroid_sim": round(sim, 4)}) centroid_overlap.sort(key=lambda x: x["centroid_sim"], reverse=True) bridge_tags = [ r for r in review_rows if float(r["best_sim"]) >= 0.70 and float(r["margin"]) < 0.02 ] bridge_tags = sorted(bridge_tags, key=lambda r: -int(r["fluffyrock_count"]))[:80] summary = { "registry_file": str(_REGISTRY_PATH), "checklist_file": str(_CHECKLIST_PATH), "seed_override_file": str(_SEED_OVERRIDES_PATH), "thresholds": { "auto_sim_min": AUTO_SIM_MIN, "auto_margin_min": AUTO_MARGIN_MIN, "review_sim_min": REVIEW_SIM_MIN, "review_margin_min": REVIEW_MARGIN_MIN, }, "n_centroids": len(category_names), "tag_group_seed_count": n_tag_group_seeds, "ignored_wiki_groups": sorted(ignored_wiki_groups), "tag_groups_file": str(_TAG_GROUPS_PATH), "tag_group_map_file": str(_TAG_GROUP_MAP_PATH), "seed_sizes": seed_sizes, "n_candidates": len(candidates), "bucket_counts": dict(bucket_counts), "high_overlap_centroid_pairs": centroid_overlap[:40], "bridge_tags_low_margin_high_sim": bridge_tags, "outputs": { "review_csv": str(_OUT_REVIEW_PATH), "summary_json": str(_OUT_SUMMARY_PATH), }, } with _OUT_SUMMARY_PATH.open("w", encoding="utf-8") as f: json.dump(summary, f, indent=2, ensure_ascii=False) print(f"Centroids built: {len(category_names)}") print(f"Candidate tags scored: {len(candidates)}") print(f"Decision buckets: {dict(bucket_counts)}") print(f"Review CSV: {_OUT_REVIEW_PATH}") print(f"Summary JSON: {_OUT_SUMMARY_PATH}") if __name__ == "__main__": main()