| """Centroid-based category suggestions using reduced TF-IDF tag vectors. | |
| This script uses e621 checklist-documented categories as seed centroids, | |
| then scores uncategorized tags against those centroids. | |
| Outputs: | |
| - data/analysis/category_centroid_review.csv | |
| - data/analysis/category_centroid_summary.json | |
| Optional seed override file: | |
| - data/analysis/category_seed_overrides.csv | |
| """ | |
from __future__ import annotations

import csv
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple

import numpy as np

from psq_rag.retrieval.state import get_tag_counts, get_tfidf_tag_vectors
from psq_rag.tagging.category_parser import parse_checklist

_REPO_ROOT = Path(__file__).resolve().parents[1]
_REGISTRY_PATH = _REPO_ROOT / "data" / "category_registry.csv"
_CHECKLIST_PATH = _REPO_ROOT / "tagging_checklist.txt"
_SEED_OVERRIDES_PATH = _REPO_ROOT / "data" / "analysis" / "category_seed_overrides.csv"
_TAG_GROUPS_PATH = _REPO_ROOT / "data" / "tag_groups.json"
_TAG_GROUP_MAP_PATH = _REPO_ROOT / "data" / "analysis" / "category_tag_group_map.csv"
_OUT_REVIEW_PATH = _REPO_ROOT / "data" / "analysis" / "category_centroid_review.csv"
_OUT_SUMMARY_PATH = _REPO_ROOT / "data" / "analysis" / "category_centroid_summary.json"

# Conservative defaults: only auto-accept when assignment is clear.
AUTO_SIM_MIN = 0.78
AUTO_MARGIN_MIN = 0.06
REVIEW_SIM_MIN = 0.65
REVIEW_MARGIN_MIN = 0.03


def _load_registry_rows(path: Path) -> List[Dict[str, str]]:
    with path.open("r", encoding="utf-8", newline="") as f:
        return list(csv.DictReader(f))


def _load_seed_overrides(path: Path) -> Dict[str, Set[str]]:
    if not path.is_file():
        return {}
    overrides: Dict[str, Set[str]] = defaultdict(set)
    with path.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            if row.get("enabled", "1").strip() not in {"1", "true", "True"}:
                continue
            category = (row.get("category_name") or "").strip()
            tag = (row.get("tag") or "").strip()
            if category and tag:
                overrides[category].add(tag)
    return overrides


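# The two template writers below create editable CSV stubs on first run so
# seeds can be curated by hand without code changes; existing files are never
# overwritten.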
def _write_seed_override_template(path: Path) -> None:
    if path.exists():
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["category_name", "tag", "enabled", "seed_note"])
        writer.writerow(["objects_props", "bed", "1", "example manual seed"])
        writer.writerow(["background_composition", "indoors", "1", "example manual seed"])
        writer.writerow(["pose_action_detail", "stretching", "1", "example manual seed"])


def _write_tag_group_map_template(path: Path) -> None:
    if path.exists():
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["category_name", "tag_group", "enabled", "seed_note"])
        writer.writerow(["clothing_detail", "clothes", "1", "e621 wiki tag group"])
        writer.writerow(["expression_detail", "facial_expressions", "1", "e621 wiki tag group"])
        writer.writerow(["objects_props", "food", "1", "e621 wiki tag group"])
        writer.writerow(["pose_action_detail", "pose", "1", "e621 wiki tag group"])


def _seed_categories_from_checklist() -> Dict[str, Set[str]]:
    categories = parse_checklist(_CHECKLIST_PATH)
    return {name: set(cat.tags) for name, cat in categories.items()}


def _seed_proposed_categories_from_registry(
    rows: List[Dict[str, str]], top_n: int = 12
) -> Dict[str, Set[str]]:
    checklist_categories = set(_seed_categories_from_checklist().keys())
    grouped: Dict[str, List[Tuple[str, int]]] = defaultdict(list)
    for row in rows:
        category = (row.get("category_name") or "").strip()
        tag = (row.get("tag") or "").strip()
        status = (row.get("category_status") or "").strip()
        if not tag or not category:
            continue
        if category in {"uncategorized_review", "nsfw_excluded"}:
            continue
        if category in checklist_categories:
            continue
        if status not in {"proposed_missing", "proposed"}:
            continue
        try:
            freq = int(row.get("tag_fluffyrock_count") or "0")
        except ValueError:
            freq = 0
        grouped[category].append((tag, freq))
    out: Dict[str, Set[str]] = {}
    for category, entries in grouped.items():
        entries.sort(key=lambda x: x[1], reverse=True)
        out[category] = {tag for tag, _ in entries[:top_n]}
    return out


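# data/tag_groups.json is expected to map a group name to a flat list of tag
# strings, e.g. {"pose": ["sitting", "standing"]} (illustrative values, not
# taken from the real file); non-list values and non-string members are skipped.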
def _seed_from_tag_groups(tag_groups_path: Path, map_path: Path) -> Tuple[Dict[str, Set[str]], int, Set[str]]:
    if not tag_groups_path.is_file() or not map_path.is_file():
        return {}, 0, set()
    with tag_groups_path.open("r", encoding="utf-8") as f:
        tag_groups = json.load(f)
    added = 0
    out: Dict[str, Set[str]] = defaultdict(set)
    ignored_wiki_groups: Set[str] = set()
    with map_path.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            if row.get("enabled", "1").strip() not in {"1", "true", "True"}:
                continue
            category = (row.get("category_name") or "").strip()
            group = (row.get("tag_group") or "").strip()
            if not category or not group:
                continue
            if category.lower().startswith("ignored_"):
                ignored_wiki_groups.add(group)
                continue
            members = tag_groups.get(group, [])
            if not isinstance(members, list):
                continue
            for tag in members:
                if isinstance(tag, str) and tag:
                    out[category].add(tag)
                    added += 1
    return out, added, ignored_wiki_groups


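# Each centroid is the mean of a category's unit-length seed vectors,
# re-normalized to unit length so a dot product with another unit vector is a
# cosine similarity. Categories with fewer than two in-vocabulary seeds are
# skipped, presumably because a single tag is too noisy to define a direction.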
def _build_centroids(
    seed_sets: Dict[str, Set[str]],
    tag_to_row: Dict[str, int],
    vectors_norm: np.ndarray,
) -> Tuple[Dict[str, np.ndarray], Dict[str, int]]:
    centroids: Dict[str, np.ndarray] = {}
    seed_sizes: Dict[str, int] = {}
    for category, seeds in seed_sets.items():
        idxs = [tag_to_row[tag] for tag in seeds if tag in tag_to_row]
        if len(idxs) < 2:
            continue
        mat = vectors_norm[idxs]
        centroid = mat.mean(axis=0)
        norm = np.linalg.norm(centroid)
        if norm == 0:
            continue
        centroids[category] = centroid / norm
        seed_sizes[category] = len(idxs)
    return centroids, seed_sizes


def _candidate_tags(rows: List[Dict[str, str]]) -> List[Tuple[str, int]]:
    seen: Set[str] = set()
    candidates: List[Tuple[str, int]] = []
    for row in rows:
        category = (row.get("category_name") or "").strip()
        tag = (row.get("tag") or "").strip()
        if category != "uncategorized_review" or not tag or tag in seen:
            continue
        seen.add(tag)
        try:
            freq = int(row.get("tag_fluffyrock_count") or "0")
        except ValueError:
            freq = 0
        candidates.append((tag, freq))
    return candidates


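# Worked example of the thresholds: top_sim=0.80 with margin=0.07 clears both
# auto cut-offs -> "auto_accept"; top_sim=0.70 with margin=0.04 clears only the
# review cut-offs -> "needs_review"; anything weaker is held.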
def _decision(top_sim: float, margin: float) -> str:
    if top_sim >= AUTO_SIM_MIN and margin >= AUTO_MARGIN_MIN:
        return "auto_accept"
    if top_sim >= REVIEW_SIM_MIN and margin >= REVIEW_MARGIN_MIN:
        return "needs_review"
    return "hold"


def main() -> None:
    if not _REGISTRY_PATH.is_file():
        raise FileNotFoundError(f"Missing registry file: {_REGISTRY_PATH}")
    if not _CHECKLIST_PATH.is_file():
        raise FileNotFoundError(f"Missing checklist file: {_CHECKLIST_PATH}")
    _write_seed_override_template(_SEED_OVERRIDES_PATH)
    _write_tag_group_map_template(_TAG_GROUP_MAP_PATH)

    rows = _load_registry_rows(_REGISTRY_PATH)
    seed_sets = _seed_categories_from_checklist()
    provisional = _seed_proposed_categories_from_registry(rows)
    for category, tags in provisional.items():
        seed_sets.setdefault(category, set()).update(tags)
    overrides = _load_seed_overrides(_SEED_OVERRIDES_PATH)
    for category, tags in overrides.items():
        seed_sets.setdefault(category, set()).update(tags)
    tag_group_seeds, n_tag_group_seeds, ignored_wiki_groups = _seed_from_tag_groups(
        _TAG_GROUPS_PATH, _TAG_GROUP_MAP_PATH
    )
    for category, tags in tag_group_seeds.items():
        seed_sets.setdefault(category, set()).update(tags)

    vectors = get_tfidf_tag_vectors()
    vectors_norm = vectors["reduced_matrix_norm"]
    tag_to_row = vectors["tag_to_row_index"]
    centroids, seed_sizes = _build_centroids(seed_sets, tag_to_row, vectors_norm)
    if not centroids:
        raise RuntimeError("No category centroids created. Check seeds and vector availability.")
    category_names = sorted(centroids.keys())
    centroid_matrix = np.stack([centroids[name] for name in category_names], axis=0)

    counts = get_tag_counts()
    candidates = _candidate_tags(rows)
    review_rows: List[Dict[str, str]] = []
    bucket_counts = defaultdict(int)
    for tag, fallback_freq in candidates:
        idx = tag_to_row.get(tag)
        if idx is None:
            continue
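        # centroid_matrix rows are unit-norm (see _build_centroids); assuming
        # reduced_matrix_norm rows are likewise L2-normalized, as the name
        # suggests, this product yields one cosine similarity per category.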
        sims = centroid_matrix @ vectors_norm[idx]
        if sims.size == 0:
            continue
        order = np.argsort(sims)[::-1]
        top_i = int(order[0])
        top2_i = int(order[1]) if sims.size > 1 else top_i
        top_sim = float(sims[top_i])
        second_sim = float(sims[top2_i])
        margin = top_sim - second_sim
        decision = _decision(top_sim, margin)
        bucket_counts[decision] += 1
        review_rows.append(
            {
                "tag": tag,
                "fluffyrock_count": str(counts.get(tag, fallback_freq)),
                "best_category": category_names[top_i],
                "best_sim": f"{top_sim:.6f}",
                "second_category": category_names[top2_i],
                "second_sim": f"{second_sim:.6f}",
                "margin": f"{margin:.6f}",
                "decision": decision,
            }
        )

    review_rows.sort(
        key=lambda r: (
            {"auto_accept": 0, "needs_review": 1, "hold": 2}[r["decision"]],
            -int(r["fluffyrock_count"]),
            -float(r["best_sim"]),
        )
    )

    _OUT_REVIEW_PATH.parent.mkdir(parents=True, exist_ok=True)
    with _OUT_REVIEW_PATH.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "tag",
                "fluffyrock_count",
                "best_category",
                "best_sim",
                "second_category",
                "second_sim",
                "margin",
                "decision",
            ],
        )
        writer.writeheader()
        writer.writerows(review_rows)

    centroid_overlap = []
    for i, c1 in enumerate(category_names):
        for j in range(i + 1, len(category_names)):
            c2 = category_names[j]
            sim = float(np.dot(centroids[c1], centroids[c2]))
            if sim >= 0.70:
                centroid_overlap.append({"category_a": c1, "category_b": c2, "centroid_sim": round(sim, 4)})
    centroid_overlap.sort(key=lambda x: x["centroid_sim"], reverse=True)

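    # "Bridge" tags score high against their best centroid but with a tiny
    # margin over the runner-up; the most frequent ones are surfaced to help
    # spot category pairs whose seed sets may need disentangling.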
    bridge_tags = [
        r
        for r in review_rows
        if float(r["best_sim"]) >= 0.70 and float(r["margin"]) < 0.02
    ]
    bridge_tags = sorted(bridge_tags, key=lambda r: -int(r["fluffyrock_count"]))[:80]

    summary = {
        "registry_file": str(_REGISTRY_PATH),
        "checklist_file": str(_CHECKLIST_PATH),
        "seed_override_file": str(_SEED_OVERRIDES_PATH),
        "thresholds": {
            "auto_sim_min": AUTO_SIM_MIN,
            "auto_margin_min": AUTO_MARGIN_MIN,
            "review_sim_min": REVIEW_SIM_MIN,
            "review_margin_min": REVIEW_MARGIN_MIN,
        },
        "n_centroids": len(category_names),
        "tag_group_seed_count": n_tag_group_seeds,
        "ignored_wiki_groups": sorted(ignored_wiki_groups),
        "tag_groups_file": str(_TAG_GROUPS_PATH),
        "tag_group_map_file": str(_TAG_GROUP_MAP_PATH),
        "seed_sizes": seed_sizes,
        "n_candidates": len(candidates),
        "bucket_counts": dict(bucket_counts),
        "high_overlap_centroid_pairs": centroid_overlap[:40],
        "bridge_tags_low_margin_high_sim": bridge_tags,
        "outputs": {
            "review_csv": str(_OUT_REVIEW_PATH),
            "summary_json": str(_OUT_SUMMARY_PATH),
        },
    }
    with _OUT_SUMMARY_PATH.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f"Centroids built: {len(category_names)}")
    print(f"Candidate tags scored: {len(candidates)}")
    print(f"Decision buckets: {dict(bucket_counts)}")
    print(f"Review CSV: {_OUT_REVIEW_PATH}")
    print(f"Summary JSON: {_OUT_SUMMARY_PATH}")


if __name__ == "__main__":
    main()