# Prompt_Squirrel_RAG/scripts/analyze_category_centroids.py
# Branch "Food Desert" — commit 6e50f4d "Consolidate probe configs and eval artifacts on main"
"""Centroid-based category suggestions using reduced TF-IDF tag vectors.
This script uses e621 checklist-documented categories as seed centroids,
then scores uncategorized tags against those centroids.
Outputs:
- data/analysis/category_centroid_review.csv
- data/analysis/category_centroid_summary.json
Optional seed override file:
- data/analysis/category_seed_overrides.csv
"""
from __future__ import annotations
import csv
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple
import numpy as np
from psq_rag.retrieval.state import get_tag_counts, get_tfidf_tag_vectors
from psq_rag.tagging.category_parser import parse_checklist
# Repository root: this script lives in scripts/, so the root is one level up.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Inputs: the tag registry CSV and the e621 tagging checklist.
_REGISTRY_PATH = _REPO_ROOT / "data" / "category_registry.csv"
_CHECKLIST_PATH = _REPO_ROOT / "tagging_checklist.txt"
# Optional user-editable seed inputs (templates are scaffolded on first run).
_SEED_OVERRIDES_PATH = _REPO_ROOT / "data" / "analysis" / "category_seed_overrides.csv"
_TAG_GROUPS_PATH = _REPO_ROOT / "data" / "tag_groups.json"
_TAG_GROUP_MAP_PATH = _REPO_ROOT / "data" / "analysis" / "category_tag_group_map.csv"
# Outputs: per-tag review CSV and run-level summary JSON.
_OUT_REVIEW_PATH = _REPO_ROOT / "data" / "analysis" / "category_centroid_review.csv"
_OUT_SUMMARY_PATH = _REPO_ROOT / "data" / "analysis" / "category_centroid_summary.json"
# Conservative defaults: only auto-accept when assignment is clear.
# A tag auto-accepts when its best cosine similarity AND its lead over the
# runner-up category both clear the AUTO_* thresholds; the looser REVIEW_*
# thresholds route it to manual review instead. Everything else is held.
AUTO_SIM_MIN = 0.78
AUTO_MARGIN_MIN = 0.06
REVIEW_SIM_MIN = 0.65
REVIEW_MARGIN_MIN = 0.03
def _load_registry_rows(path: Path) -> List[Dict[str, str]]:
with path.open("r", encoding="utf-8", newline="") as f:
return list(csv.DictReader(f))
def _load_seed_overrides(path: Path) -> Dict[str, Set[str]]:
if not path.is_file():
return {}
overrides: Dict[str, Set[str]] = defaultdict(set)
with path.open("r", encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
if row.get("enabled", "1").strip() not in {"1", "true", "True"}:
continue
category = (row.get("category_name") or "").strip()
tag = (row.get("tag") or "").strip()
if category and tag:
overrides[category].add(tag)
return overrides
def _write_seed_override_template(path: Path) -> None:
if path.exists():
return
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8", newline="") as f:
writer = csv.writer(f)
writer.writerow(["category_name", "tag", "enabled", "seed_note"])
writer.writerow(["objects_props", "bed", "1", "example manual seed"])
writer.writerow(["background_composition", "indoors", "1", "example manual seed"])
writer.writerow(["pose_action_detail", "stretching", "1", "example manual seed"])
def _write_tag_group_map_template(path: Path) -> None:
if path.exists():
return
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8", newline="") as f:
writer = csv.writer(f)
writer.writerow(["category_name", "tag_group", "enabled", "seed_note"])
writer.writerow(["clothing_detail", "clothes", "1", "e621 wiki tag group"])
writer.writerow(["expression_detail", "facial_expressions", "1", "e621 wiki tag group"])
writer.writerow(["objects_props", "food", "1", "e621 wiki tag group"])
writer.writerow(["pose_action_detail", "pose", "1", "e621 wiki tag group"])
def _seed_categories_from_checklist() -> Dict[str, Set[str]]:
    """Seed every checklist-documented category with its documented tags."""
    seeds: Dict[str, Set[str]] = {}
    for name, category in parse_checklist(_CHECKLIST_PATH).items():
        seeds[name] = set(category.tags)
    return seeds
def _seed_proposed_categories_from_registry(rows: List[Dict[str, str]], top_n: int = 12) -> Dict[str, Set[str]]:
    """Seed provisional (non-checklist) categories with their top-frequency tags.

    Registry rows with status ``proposed``/``proposed_missing`` are grouped by
    category; the ``top_n`` most frequent tags of each group become its seeds.
    Checklist categories and the review/excluded buckets are ignored.
    """
    known = set(_seed_categories_from_checklist().keys())
    skip_buckets = {"uncategorized_review", "nsfw_excluded"}
    accepted_status = {"proposed_missing", "proposed"}
    by_category: Dict[str, List[Tuple[str, int]]] = defaultdict(list)
    for record in rows:
        name = (record.get("category_name") or "").strip()
        seed_tag = (record.get("tag") or "").strip()
        state = (record.get("category_status") or "").strip()
        if not seed_tag or not name or name in skip_buckets or name in known:
            continue
        if state not in accepted_status:
            continue
        # Unparseable frequencies count as zero rather than aborting the run.
        try:
            count = int(record.get("tag_fluffyrock_count") or "0")
        except ValueError:
            count = 0
        by_category[name].append((seed_tag, count))
    result: Dict[str, Set[str]] = {}
    for name, scored in by_category.items():
        scored.sort(key=lambda pair: pair[1], reverse=True)
        result[name] = {tag for tag, _ in scored[:top_n]}
    return result
def _seed_from_tag_groups(tag_groups_path: Path, map_path: Path) -> Tuple[Dict[str, Set[str]], int, Set[str]]:
if not tag_groups_path.is_file() or not map_path.is_file():
return {}, 0, set()
with tag_groups_path.open("r", encoding="utf-8") as f:
tag_groups = json.load(f)
added = 0
out: Dict[str, Set[str]] = defaultdict(set)
ignored_wiki_groups: Set[str] = set()
with map_path.open("r", encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
if row.get("enabled", "1").strip() not in {"1", "true", "True"}:
continue
category = (row.get("category_name") or "").strip()
group = (row.get("tag_group") or "").strip()
if not category or not group:
continue
if category.lower().startswith("ignored_"):
ignored_wiki_groups.add(group)
continue
members = tag_groups.get(group, [])
if not isinstance(members, list):
continue
for tag in members:
if isinstance(tag, str) and tag:
out[category].add(tag)
added += 1
return out, added, ignored_wiki_groups
def _build_centroids(
seed_sets: Dict[str, Set[str]],
tag_to_row: Dict[str, int],
vectors_norm: np.ndarray,
) -> Tuple[Dict[str, np.ndarray], Dict[str, int]]:
centroids: Dict[str, np.ndarray] = {}
seed_sizes: Dict[str, int] = {}
for category, seeds in seed_sets.items():
idxs = [tag_to_row[tag] for tag in seeds if tag in tag_to_row]
if len(idxs) < 2:
continue
mat = vectors_norm[idxs]
centroid = mat.mean(axis=0)
norm = np.linalg.norm(centroid)
if norm == 0:
continue
centroids[category] = centroid / norm
seed_sizes[category] = len(idxs)
return centroids, seed_sizes
def _candidate_tags(rows: List[Dict[str, str]]) -> List[Tuple[str, int]]:
seen: Set[str] = set()
candidates: List[Tuple[str, int]] = []
for row in rows:
category = (row.get("category_name") or "").strip()
tag = (row.get("tag") or "").strip()
if category != "uncategorized_review" or not tag or tag in seen:
continue
seen.add(tag)
try:
freq = int(row.get("tag_fluffyrock_count") or "0")
except ValueError:
freq = 0
candidates.append((tag, freq))
return candidates
def _decision(top_sim: float, margin: float) -> str:
    """Bucket a scored tag by its best similarity and runner-up margin.

    Returns "auto_accept", "needs_review", or "hold", in decreasing order of
    confidence, using the module-level threshold constants.
    """
    clears_auto = top_sim >= AUTO_SIM_MIN and margin >= AUTO_MARGIN_MIN
    if clears_auto:
        return "auto_accept"
    clears_review = top_sim >= REVIEW_SIM_MIN and margin >= REVIEW_MARGIN_MIN
    return "needs_review" if clears_review else "hold"
def main() -> None:
    """Build category centroids, score uncategorized tags, and write review artifacts.

    Pipeline: validate inputs, scaffold editable seed templates, union seed
    sets from four sources (checklist, registry proposals, manual overrides,
    wiki tag groups), build unit centroids in the reduced TF-IDF space, score
    every uncategorized tag against them, then emit a review CSV plus a
    summary JSON with centroid-overlap and bridge-tag diagnostics.

    Raises:
        FileNotFoundError: if the registry or checklist file is missing.
        RuntimeError: if no centroid could be built from the seeds.
    """
    if not _REGISTRY_PATH.is_file():
        raise FileNotFoundError(f"Missing registry file: {_REGISTRY_PATH}")
    if not _CHECKLIST_PATH.is_file():
        raise FileNotFoundError(f"Missing checklist file: {_CHECKLIST_PATH}")
    # Scaffold editable seed templates on first run (no-ops when files exist).
    _write_seed_override_template(_SEED_OVERRIDES_PATH)
    _write_tag_group_map_template(_TAG_GROUP_MAP_PATH)
    rows = _load_registry_rows(_REGISTRY_PATH)
    # Seeds are additive: checklist categories first, then registry proposals,
    # manual overrides, and wiki tag groups are unioned in.
    seed_sets = _seed_categories_from_checklist()
    provisional = _seed_proposed_categories_from_registry(rows)
    for category, tags in provisional.items():
        seed_sets.setdefault(category, set()).update(tags)
    overrides = _load_seed_overrides(_SEED_OVERRIDES_PATH)
    for category, tags in overrides.items():
        seed_sets.setdefault(category, set()).update(tags)
    tag_group_seeds, n_tag_group_seeds, ignored_wiki_groups = _seed_from_tag_groups(_TAG_GROUPS_PATH, _TAG_GROUP_MAP_PATH)
    for category, tags in tag_group_seeds.items():
        seed_sets.setdefault(category, set()).update(tags)
    vectors = get_tfidf_tag_vectors()
    # Pre-normalized reduced vectors: dot products below are cosine similarities.
    vectors_norm = vectors["reduced_matrix_norm"]
    tag_to_row = vectors["tag_to_row_index"]
    centroids, seed_sizes = _build_centroids(seed_sets, tag_to_row, vectors_norm)
    if not centroids:
        raise RuntimeError("No category centroids created. Check seeds and vector availability.")
    # Fixed (sorted) category order so matrix rows and names stay aligned.
    category_names = sorted(centroids.keys())
    centroid_matrix = np.stack([centroids[name] for name in category_names], axis=0)
    counts = get_tag_counts()
    candidates = _candidate_tags(rows)
    review_rows: List[Dict[str, str]] = []
    bucket_counts = defaultdict(int)
    for tag, fallback_freq in candidates:
        idx = tag_to_row.get(tag)
        if idx is None:
            continue  # tag has no vector; cannot be scored
        # Cosine similarity of this tag against every category centroid.
        sims = centroid_matrix @ vectors_norm[idx]
        if sims.size == 0:
            continue
        order = np.argsort(sims)[::-1]
        top_i = int(order[0])
        # With a single centroid the runner-up degenerates to the winner (margin 0).
        top2_i = int(order[1]) if sims.size > 1 else top_i
        top_sim = float(sims[top_i])
        second_sim = float(sims[top2_i])
        margin = top_sim - second_sim
        decision = _decision(top_sim, margin)
        bucket_counts[decision] += 1
        review_rows.append(
            {
                "tag": tag,
                # Prefer the live count; fall back to the registry's value.
                "fluffyrock_count": str(counts.get(tag, fallback_freq)),
                "best_category": category_names[top_i],
                "best_sim": f"{top_sim:.6f}",
                "second_category": category_names[top2_i],
                "second_sim": f"{second_sim:.6f}",
                "margin": f"{margin:.6f}",
                "decision": decision,
            }
        )
    # Reviewer-friendly order: decision bucket, then frequency desc, then similarity desc.
    review_rows.sort(
        key=lambda r: (
            {"auto_accept": 0, "needs_review": 1, "hold": 2}[r["decision"]],
            -int(r["fluffyrock_count"]),
            -float(r["best_sim"]),
        )
    )
    _OUT_REVIEW_PATH.parent.mkdir(parents=True, exist_ok=True)
    with _OUT_REVIEW_PATH.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "tag",
                "fluffyrock_count",
                "best_category",
                "best_sim",
                "second_category",
                "second_sim",
                "margin",
                "decision",
            ],
        )
        writer.writeheader()
        writer.writerows(review_rows)
    # Flag centroid pairs that are nearly collinear — likely duplicate or
    # overlapping category definitions worth merging.
    centroid_overlap = []
    for i, c1 in enumerate(category_names):
        for j in range(i + 1, len(category_names)):
            c2 = category_names[j]
            sim = float(np.dot(centroids[c1], centroids[c2]))
            if sim >= 0.70:
                centroid_overlap.append({"category_a": c1, "category_b": c2, "centroid_sim": round(sim, 4)})
    centroid_overlap.sort(key=lambda x: x["centroid_sim"], reverse=True)
    # "Bridge" tags: high best-similarity but tiny margin — they sit between
    # two categories and deserve manual placement.
    bridge_tags = [
        r
        for r in review_rows
        if float(r["best_sim"]) >= 0.70 and float(r["margin"]) < 0.02
    ]
    bridge_tags = sorted(bridge_tags, key=lambda r: -int(r["fluffyrock_count"]))[:80]
    summary = {
        "registry_file": str(_REGISTRY_PATH),
        "checklist_file": str(_CHECKLIST_PATH),
        "seed_override_file": str(_SEED_OVERRIDES_PATH),
        "thresholds": {
            "auto_sim_min": AUTO_SIM_MIN,
            "auto_margin_min": AUTO_MARGIN_MIN,
            "review_sim_min": REVIEW_SIM_MIN,
            "review_margin_min": REVIEW_MARGIN_MIN,
        },
        "n_centroids": len(category_names),
        "tag_group_seed_count": n_tag_group_seeds,
        "ignored_wiki_groups": sorted(ignored_wiki_groups),
        "tag_groups_file": str(_TAG_GROUPS_PATH),
        "tag_group_map_file": str(_TAG_GROUP_MAP_PATH),
        "seed_sizes": seed_sizes,
        "n_candidates": len(candidates),
        "bucket_counts": dict(bucket_counts),
        "high_overlap_centroid_pairs": centroid_overlap[:40],
        "bridge_tags_low_margin_high_sim": bridge_tags,
        "outputs": {
            "review_csv": str(_OUT_REVIEW_PATH),
            "summary_json": str(_OUT_SUMMARY_PATH),
        },
    }
    with _OUT_SUMMARY_PATH.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f"Centroids built: {len(category_names)}")
    print(f"Candidate tags scored: {len(candidates)}")
    print(f"Decision buckets: {dict(bucket_counts)}")
    print(f"Review CSV: {_OUT_REVIEW_PATH}")
    print(f"Summary JSON: {_OUT_SUMMARY_PATH}")


if __name__ == "__main__":
    main()