"""Guided facet assignment for uncovered tags using vectors + lightweight rules. Purpose: - Assign high-frequency uncovered tags into semantically useful facets. - Avoid naive free clustering by using seeded centroids + lexical constraints. - Keep output compact: exactly two overwrite-in-place files. Outputs (overwritten each run): - data/analysis/guided_facet_assignments.csv - data/analysis/guided_facet_summary.json """ from __future__ import annotations import csv import json import re from collections import Counter, defaultdict from pathlib import Path from typing import Dict, List, Set, Tuple import numpy as np from psq_rag.retrieval.state import get_tfidf_tag_vectors REPO = Path(__file__).resolve().parents[1] COUNTS_CSV = REPO / "fluffyrock_3m.csv" SAMPLE_JSONL = REPO / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl" WIKI_GROUPS_JSON = REPO / "data" / "tag_groups.json" REGISTRY_CSV = REPO / "data" / "category_registry.csv" PROPOSAL_CSV = REPO / "data" / "analysis" / "category_expansion_proposal.csv" OUT_ASSIGN = REPO / "data" / "analysis" / "guided_facet_assignments.csv" OUT_SUMMARY = REPO / "data" / "analysis" / "guided_facet_summary.json" MIN_COUNT = 200 FACETS = { "species_taxonomy": { "seeds": { "canid", "canis", "felid", "felis", "equid", "leporid", "domestic_dog", "domestic_cat", "wolf", "fox", "bird", "bear", "unicorn", "dragon", "reptile", "bovid", "pony", "horse", }, "patterns": [r"canid|canis|felid|felis|equid|leporid|domestic_|wolf|fox|bird|bear|unicorn|dragon|reptile|bovid|pony|horse|pantherine"], "sim_min": 0.74, "margin_min": 0.03, }, "character_traits": { "seeds": { "young", "cub", "vein", "muscular", "slightly_chubby", "overweight", "curvy_figure", "thick_thighs", "wide_hips", "huge_breasts", "huge_butt", "abs", "pecs", }, "patterns": [r"young|cub|vein|muscular|chubby|overweight|curvy|thigh|hips|abs|pecs|belly|cleavage"], "sim_min": 0.73, "margin_min": 0.03, }, "clothing_coverage": { "seeds": {"topless", "bottomless", "barefoot", "panties", "thigh_highs", "stockings", "clothed"}, "patterns": [r"topless|bottomless|barefoot|panties|thigh_highs|stockings|underwear|nude"], "sim_min": 0.70, "margin_min": 0.02, }, "symbol_text_misc": { "seeds": {"<3", "text", "symbol", "emblem", "logo"}, "patterns": [r"^<3$", r"text|symbol|logo|emblem|heart"], "sim_min": 0.0, "margin_min": 0.0, }, "fluids_explicit_sensitive": { "seeds": {"bodily_fluids", "saliva", "sweat", "dripping", "cum", "nude", "nipples"}, "patterns": [r"fluid|saliva|sweat|drip|cum|nude|nipple|areola|bodily_fluids"], "sim_min": 0.68, "margin_min": 0.01, }, } # Light lexical boosts; still gated by thresholds for most facets. LEXICAL_BOOST = 0.08 def load_counts(path: Path) -> Dict[str, int]: counts: Dict[str, int] = {} with path.open("r", encoding="utf-8", newline="") as f: reader = csv.reader(f) for row in reader: if len(row) < 3: continue try: counts[row[0]] = int(row[2]) if row[2] else 0 except ValueError: counts[row[0]] = 0 return counts def load_sample_tag_occurrences(path: Path, counts: Dict[str, int], min_count: int) -> Counter: occ = Counter() with path.open("r", encoding="utf-8") as f: for line in f: obj = json.loads(line) raw = obj.get("tags_ground_truth_categorized", "") if not raw: continue try: categorized = json.loads(raw) except Exception: continue tags: Set[str] = set() if isinstance(categorized, dict): for vals in categorized.values(): if isinstance(vals, list): for t in vals: if isinstance(t, str) and counts.get(t, 0) >= min_count: tags.add(t) occ.update(tags) return occ def load_base_groups() -> Dict[str, Set[str]]: with WIKI_GROUPS_JSON.open("r", encoding="utf-8") as f: wiki = json.load(f) groups = {k: set(v) for k, v in wiki.items() if isinstance(v, list)} with REGISTRY_CSV.open("r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) for row in reader: if (row.get("category_enabled") or "").strip() not in {"1", "true", "True"}: continue c = (row.get("category_name") or "").strip() t = (row.get("tag") or "").strip() if c and t: groups.setdefault(f"cat:{c}", set()).add(t) with PROPOSAL_CSV.open("r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) for row in reader: if row.get("proposed_action") not in {"new_category", "merge_existing"}: continue tgt = (row.get("target_category") or "").strip() tag = (row.get("tag") or "").strip() if tgt and tag and tgt != "none": groups.setdefault(f"cat:{tgt}", set()).add(tag) return groups def build_centroids(tag_to_row: Dict[str, int], vectors_norm: np.ndarray) -> Dict[str, np.ndarray]: centroids: Dict[str, np.ndarray] = {} for facet, cfg in FACETS.items(): seed_idxs = [tag_to_row[t] for t in cfg["seeds"] if t in tag_to_row] if len(seed_idxs) < 2: continue mat = vectors_norm[seed_idxs] c = mat.mean(axis=0) n = np.linalg.norm(c) if n == 0: continue centroids[facet] = c / n return centroids def lexical_match_score(tag: str, facet: str) -> float: patterns = FACETS[facet]["patterns"] for p in patterns: if re.search(p, tag): return LEXICAL_BOOST return 0.0 def decision_for(tag: str, facet: str, sim: float, margin: float, lexical: float) -> str: # Symbol/text facet is mostly lexical by design. if facet == "symbol_text_misc": if lexical > 0.0 or re.search(r"[^a-z0-9_()/-]", tag): return "auto_assign" return "review" cfg = FACETS[facet] score = sim + lexical if score >= cfg["sim_min"] and margin >= cfg["margin_min"]: return "auto_assign" return "review" def coverage_pct(groups: Dict[str, Set[str]], tags: Set[str]) -> float: covered = sum(1 for t in tags if any(t in g for g in groups.values())) return round((covered / len(tags) * 100.0), 2) if tags else 0.0 def greedy_top15_pct(groups: Dict[str, Set[str]], occ: Counter) -> float: uncovered = Counter(occ) total = sum(occ.values()) covered = 0 chosen: Set[str] = set() for _ in range(15): best_g = None best_gain = 0 best_new = [] for g, tags in groups.items(): if g in chosen: continue gain = 0 new_tags = [] for t in tags: c = uncovered.get(t, 0) if c > 0: gain += c new_tags.append(t) if gain > best_gain: best_g = g best_gain = gain best_new = new_tags if not best_g or best_gain <= 0: break chosen.add(best_g) for t in best_new: uncovered[t] = 0 covered += best_gain return round((covered / total) * 100.0, 2) if total else 0.0 def main() -> None: counts = load_counts(COUNTS_CSV) occ = load_sample_tag_occurrences(SAMPLE_JSONL, counts, MIN_COUNT) all_tags = set(occ.keys()) base_groups = load_base_groups() covered_base = {t for t in all_tags if any(t in g for g in base_groups.values())} uncovered = sorted(all_tags - covered_base, key=lambda t: (counts.get(t, 0), occ[t]), reverse=True) vectors = get_tfidf_tag_vectors() vectors_norm = vectors["reduced_matrix_norm"] tag_to_row = vectors["tag_to_row_index"] centroids = build_centroids(tag_to_row, vectors_norm) facet_names = sorted(centroids.keys()) C = np.stack([centroids[f] for f in facet_names], axis=0) rows: List[Dict[str, str]] = [] action_counts = Counter() facet_auto_counts = Counter() for tag in uncovered: if tag not in tag_to_row: continue sims = C @ vectors_norm[tag_to_row[tag]] order = np.argsort(sims)[::-1] i1 = int(order[0]) i2 = int(order[1]) if sims.size > 1 else i1 best_facet = facet_names[i1] best_sim = float(sims[i1]) second_facet = facet_names[i2] second_sim = float(sims[i2]) margin = best_sim - second_sim lex = lexical_match_score(tag, best_facet) score = best_sim + lex decision = decision_for(tag, best_facet, best_sim, margin, lex) action_counts[decision] += 1 if decision == "auto_assign": facet_auto_counts[best_facet] += 1 rows.append( { "tag": tag, "fluffyrock_count": str(counts.get(tag, 0)), "sample_occurrences": str(occ[tag]), "best_facet": best_facet, "best_sim": f"{best_sim:.6f}", "lexical_boost": f"{lex:.2f}", "score": f"{score:.6f}", "second_facet": second_facet, "second_sim": f"{second_sim:.6f}", "margin": f"{margin:.6f}", "decision": decision, } ) rows.sort(key=lambda r: (r["decision"] != "auto_assign", -int(r["fluffyrock_count"]))) OUT_ASSIGN.parent.mkdir(parents=True, exist_ok=True) with OUT_ASSIGN.open("w", encoding="utf-8", newline="") as f: writer = csv.DictWriter( f, fieldnames=[ "tag", "fluffyrock_count", "sample_occurrences", "best_facet", "best_sim", "lexical_boost", "score", "second_facet", "second_sim", "margin", "decision", ], ) writer.writeheader() writer.writerows(rows) # Coverage projection, with and without explicit-sensitive facet enabled. projected_all = {k: set(v) for k, v in base_groups.items()} projected_no_explicit = {k: set(v) for k, v in base_groups.items()} facet_to_group = { "species_taxonomy": "cat:species_specific", "character_traits": "cat:character_traits", "clothing_coverage": "cat:clothing_detail", "symbol_text_misc": "cat:miscellaneous", "fluids_explicit_sensitive": "cat:explicit_sensitive", } for r in rows: if r["decision"] != "auto_assign": continue tag = r["tag"] facet = r["best_facet"] group_key = facet_to_group[facet] projected_all.setdefault(group_key, set()).add(tag) if facet != "fluids_explicit_sensitive": projected_no_explicit.setdefault(group_key, set()).add(tag) summary = { "min_count": MIN_COUNT, "n_unique_tags_considered": len(all_tags), "n_uncovered_before_guided_facets": len(uncovered), "facet_names": facet_names, "decision_counts": dict(action_counts), "auto_assign_counts_by_facet": dict(facet_auto_counts), "coverage": { "baseline_unique_pct": coverage_pct(base_groups, all_tags), "baseline_top15_pct": greedy_top15_pct(base_groups, occ), "projected_unique_pct_with_explicit_facet": coverage_pct(projected_all, all_tags), "projected_top15_pct_with_explicit_facet": greedy_top15_pct(projected_all, occ), "projected_unique_pct_without_explicit_facet": coverage_pct(projected_no_explicit, all_tags), "projected_top15_pct_without_explicit_facet": greedy_top15_pct(projected_no_explicit, occ), }, "outputs": { "assignments_csv": str(OUT_ASSIGN), "summary_json": str(OUT_SUMMARY), }, } with OUT_SUMMARY.open("w", encoding="utf-8") as f: json.dump(summary, f, indent=2, ensure_ascii=False) print("Uncovered before:", len(uncovered)) print("Decisions:", dict(action_counts)) print("Auto by facet:", dict(facet_auto_counts)) print("Coverage:", summary["coverage"]) print("Outputs:", OUT_ASSIGN, OUT_SUMMARY) if __name__ == "__main__": main()