# Prompt_Squirrel_RAG / scripts /analyze_guided_facet_assignment.py
# Food Desert
# Consolidate probe configs and eval artifacts on main
# 6e50f4d
"""Guided facet assignment for uncovered tags using vectors + lightweight rules.
Purpose:
- Assign high-frequency uncovered tags into semantically useful facets.
- Avoid naive free clustering by using seeded centroids + lexical constraints.
- Keep output compact: exactly two overwrite-in-place files.
Outputs (overwritten each run):
- data/analysis/guided_facet_assignments.csv
- data/analysis/guided_facet_summary.json
"""
from __future__ import annotations
import csv
import json
import re
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple
import numpy as np
from psq_rag.retrieval.state import get_tfidf_tag_vectors
# Repository root: the parent of the directory containing this script.
REPO = Path(__file__).resolve().parents[1]
# Tag statistics CSV; load_counts() takes the tag from column 0 and its count from column 2.
COUNTS_CSV = REPO / "fluffyrock_3m.csv"
# JSONL evaluation sample; each record's "tags_ground_truth_categorized" field supplies the tags.
SAMPLE_JSONL = REPO / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
# Baseline group sources, merged by load_base_groups() to decide which tags are already covered.
WIKI_GROUPS_JSON = REPO / "data" / "tag_groups.json"
REGISTRY_CSV = REPO / "data" / "category_registry.csv"
PROPOSAL_CSV = REPO / "data" / "analysis" / "category_expansion_proposal.csv"
# Outputs, overwritten on every run.
OUT_ASSIGN = REPO / "data" / "analysis" / "guided_facet_assignments.csv"
OUT_SUMMARY = REPO / "data" / "analysis" / "guided_facet_summary.json"
# Only tags whose count in COUNTS_CSV is >= this value are considered at all.
MIN_COUNT = 200
# Facet definitions. Per facet:
#   seeds      - tags whose normalized vectors are averaged into the facet centroid
#                (build_centroids() drops a facet with fewer than 2 seeds present in the index).
#   patterns   - regexes tried with re.search(); any hit adds LEXICAL_BOOST when this facet is
#                the tag's best facet. They are unanchored substring matches, so short
#                alternatives such as "abs", "cub" or "cum" also hit longer, unrelated tags.
#   sim_min    - minimum (similarity + lexical boost) required for "auto_assign".
#   margin_min - minimum gap between best and second-best facet similarity for "auto_assign".
FACETS = {
    "species_taxonomy": {
        "seeds": {
            "canid", "canis", "felid", "felis", "equid", "leporid", "domestic_dog", "domestic_cat",
            "wolf", "fox", "bird", "bear", "unicorn", "dragon", "reptile", "bovid", "pony", "horse",
        },
        "patterns": [r"canid|canis|felid|felis|equid|leporid|domestic_|wolf|fox|bird|bear|unicorn|dragon|reptile|bovid|pony|horse|pantherine"],
        "sim_min": 0.74,
        "margin_min": 0.03,
    },
    "character_traits": {
        "seeds": {
            "young", "cub", "vein", "muscular", "slightly_chubby", "overweight", "curvy_figure",
            "thick_thighs", "wide_hips", "huge_breasts", "huge_butt", "abs", "pecs",
        },
        # NOTE(review): "thigh" here also matches clothing tags such as "thigh_highs".
        "patterns": [r"young|cub|vein|muscular|chubby|overweight|curvy|thigh|hips|abs|pecs|belly|cleavage"],
        "sim_min": 0.73,
        "margin_min": 0.03,
    },
    "clothing_coverage": {
        "seeds": {"topless", "bottomless", "barefoot", "panties", "thigh_highs", "stockings", "clothed"},
        # NOTE(review): "nude" is also a seed and pattern of fluids_explicit_sensitive; which facet
        # a nude-related tag lands in depends on which centroid it is closest to.
        "patterns": [r"topless|bottomless|barefoot|panties|thigh_highs|stockings|underwear|nude"],
        "sim_min": 0.70,
        "margin_min": 0.02,
    },
    "symbol_text_misc": {
        "seeds": {"<3", "text", "symbol", "emblem", "logo"},
        "patterns": [r"^<3$", r"text|symbol|logo|emblem|heart"],
        # Not consulted: decision_for() decides this facet purely lexically.
        "sim_min": 0.0,
        "margin_min": 0.0,
    },
    "fluids_explicit_sensitive": {
        "seeds": {"bodily_fluids", "saliva", "sweat", "dripping", "cum", "nude", "nipples"},
        "patterns": [r"fluid|saliva|sweat|drip|cum|nude|nipple|areola|bodily_fluids"],
        "sim_min": 0.68,
        "margin_min": 0.01,
    },
}
# Light lexical boost added to the best-facet similarity when one of that facet's patterns
# matches. Threshold-gated facets still need score >= sim_min; for symbol_text_misc any
# match alone auto-assigns.
LEXICAL_BOOST = 0.08
def load_counts(path: Path) -> Dict[str, int]:
    """Map each tag to the integer count found in column 2 of a CSV file.

    Rows with fewer than three fields are ignored. An empty or non-integer count
    is recorded as 0. A header row, if present, is treated like any other row.
    If a tag appears more than once, the last occurrence wins.
    """
    result: Dict[str, int] = {}
    with path.open("r", encoding="utf-8", newline="") as handle:
        for fields in csv.reader(handle):
            if len(fields) >= 3:
                name, raw_count = fields[0], fields[2]
                try:
                    value = int(raw_count) if raw_count else 0
                except ValueError:
                    value = 0
                result[name] = value
    return result
def load_sample_tag_occurrences(path: Path, counts: Dict[str, int], min_count: int) -> Counter:
    """Count, per tag, how many sample records contain it (each tag at most once per record).

    Each JSONL record's ``tags_ground_truth_categorized`` field maps category -> list of tags.
    The field may be a JSON-encoded string or an already-decoded object. Only tags whose count
    in ``counts`` is at least ``min_count`` are kept.

    Skipped input:
    - blank lines;
    - records that are not JSON objects;
    - records whose categorized field is missing, empty, not valid JSON, or not an object.
    """
    occ: Counter = Counter()
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank/trailing lines instead of crashing json.loads
            obj = json.loads(line)
            if not isinstance(obj, dict):
                continue
            raw = obj.get("tags_ground_truth_categorized", "")
            if not raw:
                continue
            if isinstance(raw, str):
                try:
                    categorized = json.loads(raw)
                except json.JSONDecodeError:
                    continue
            else:
                # Already decoded; previously json.loads(dict) raised TypeError and the
                # record was silently dropped.
                categorized = raw
            if not isinstance(categorized, dict):
                continue
            # A set, so a tag listed under several categories counts once per record.
            tags: Set[str] = {
                t
                for vals in categorized.values()
                if isinstance(vals, list)
                for t in vals
                if isinstance(t, str) and counts.get(t, 0) >= min_count
            }
            occ.update(tags)
    return occ
def load_base_groups() -> Dict[str, Set[str]]:
    """Build the baseline tag groups that define which tags are already covered.

    Sources, merged in this order:
    1. WIKI_GROUPS_JSON: each top-level entry whose value is a list becomes a group under its own key.
    2. REGISTRY_CSV: rows with a truthy ``category_enabled`` ("1" or "true", any case) add ``tag``
       to group ``cat:<category_name>``.
    3. PROPOSAL_CSV: rows with ``proposed_action`` of "new_category" or "merge_existing" add
       ``tag`` to group ``cat:<target_category>``, except when the target is "none".

    Returns:
        Mapping of group name -> set of tags.
    """
    with WIKI_GROUPS_JSON.open("r", encoding="utf-8") as f:
        wiki = json.load(f)
    groups: Dict[str, Set[str]] = {k: set(v) for k, v in wiki.items() if isinstance(v, list)}

    with REGISTRY_CSV.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            # Case-insensitive so spreadsheet exports that write "TRUE" are not silently disabled.
            if (row.get("category_enabled") or "").strip().lower() not in {"1", "true"}:
                continue
            c = (row.get("category_name") or "").strip()
            t = (row.get("tag") or "").strip()
            if c and t:
                groups.setdefault(f"cat:{c}", set()).add(t)

    with PROPOSAL_CSV.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            # Stripped like every other field so stray whitespace does not drop the row.
            if (row.get("proposed_action") or "").strip() not in {"new_category", "merge_existing"}:
                continue
            tgt = (row.get("target_category") or "").strip()
            tag = (row.get("tag") or "").strip()
            if tgt and tag and tgt != "none":
                groups.setdefault(f"cat:{tgt}", set()).add(tag)
    return groups
def build_centroids(
    tag_to_row: Dict[str, int],
    vectors_norm: np.ndarray,
    facets: Dict[str, dict] | None = None,
    min_seeds: int = 2,
) -> Dict[str, np.ndarray]:
    """Build one unit-length centroid per facet from the mean of its seed-tag vectors.

    Args:
        tag_to_row: tag -> row index into ``vectors_norm``.
        vectors_norm: matrix whose rows are the tag vectors, indexed via ``tag_to_row``.
        facets: facet name -> config containing a ``"seeds"`` collection. Defaults to ``FACETS``.
        min_seeds: a facet is skipped when fewer seeds than this are present in ``tag_to_row``.

    Returns:
        Facet name -> normalized centroid. Facets with too few indexed seeds, or whose mean
        vector has zero norm, are omitted.
    """
    if facets is None:
        facets = FACETS
    centroids: Dict[str, np.ndarray] = {}
    for facet, cfg in facets.items():
        # Sorted: str set order varies with PYTHONHASHSEED, and the row order changes the
        # floating-point sum, so unsorted seeds make centroids differ between runs.
        seed_idxs = [tag_to_row[t] for t in sorted(cfg["seeds"]) if t in tag_to_row]
        if not seed_idxs or len(seed_idxs) < min_seeds:
            continue
        c = vectors_norm[seed_idxs].mean(axis=0)
        n = np.linalg.norm(c)
        if n == 0:
            continue
        centroids[facet] = c / n
    return centroids
def lexical_match_score(tag: str, facet: str) -> float:
    """Return LEXICAL_BOOST if any of ``facet``'s regex patterns is found in ``tag``, else 0.0.

    Patterns are applied with ``re.search``; they are unanchored unless a pattern anchors itself.
    """
    matched = any(re.search(pattern, tag) for pattern in FACETS[facet]["patterns"])
    return LEXICAL_BOOST if matched else 0.0
def decision_for(tag: str, facet: str, sim: float, margin: float, lexical: float) -> str:
    """Classify a tag's best-facet match as "auto_assign" or "review".

    symbol_text_misc is decided lexically. It is auto-assigned when a pattern matched
    (``lexical > 0``) or when the tag contains a character outside ``[a-z0-9_()/-]``,
    e.g. ``<3``. Every other facet needs ``sim + lexical >= sim_min`` and
    ``margin >= margin_min`` from its FACETS entry.
    """
    if facet != "symbol_text_misc":
        thresholds = FACETS[facet]
        passes = sim + lexical >= thresholds["sim_min"] and margin >= thresholds["margin_min"]
        return "auto_assign" if passes else "review"
    if lexical > 0.0:
        return "auto_assign"
    return "auto_assign" if re.search(r"[^a-z0-9_()/-]", tag) else "review"
def coverage_pct(groups: Dict[str, Set[str]], tags: Set[str]) -> float:
    """Percentage (rounded to 2 decimals) of ``tags`` found in at least one group; 0.0 for no tags."""
    if not tags:
        return 0.0
    member_sets = list(groups.values())
    n_hit = len([t for t in tags if any(t in members for members in member_sets)])
    return round(n_hit / len(tags) * 100.0, 2)
def greedy_top15_pct(groups: Dict[str, Set[str]], occ: Counter, k: int = 15) -> float:
    """Share of all sample occurrences covered by up to ``k`` greedily chosen groups.

    Each step picks the not-yet-chosen group that covers the largest number of still-uncovered
    occurrences (ties go to the group seen first when iterating ``groups``). It stops early
    when no group adds anything.

    Args:
        groups: group name -> tags in that group.
        occ: tag -> occurrence count in the sample.
        k: maximum number of groups to choose (15 by default, as the name implies).

    Returns:
        Covered percentage of ``sum(occ.values())``, rounded to 2 decimals; 0.0 if that total is 0.
    """
    remaining = Counter(occ)
    total = sum(occ.values())
    covered = 0
    chosen: Set[str] = set()
    for _ in range(k):
        best_g = None
        best_gain = 0
        best_new: List[str] = []
        for g, tags in groups.items():
            if g in chosen:
                continue
            new_tags = [t for t in tags if remaining.get(t, 0) > 0]
            gain = sum(remaining[t] for t in new_tags)
            if gain > best_gain:
                best_g, best_gain, best_new = g, gain, new_tags
        # `is None`, not truthiness: an empty-string group key is still a valid choice.
        # best_g is only set when gain > 0, so this also covers "no group adds anything".
        if best_g is None:
            break
        chosen.add(best_g)
        for t in best_new:
            remaining[t] = 0
        covered += best_gain
    return round(covered / total * 100.0, 2) if total else 0.0
# Base-group key that each facet's auto-assigned tags are projected into for the coverage estimate.
_FACET_TO_GROUP = {
    "species_taxonomy": "cat:species_specific",
    "character_traits": "cat:character_traits",
    "clothing_coverage": "cat:clothing_detail",
    "symbol_text_misc": "cat:miscellaneous",
    "fluids_explicit_sensitive": "cat:explicit_sensitive",
}

# Column order of the assignments CSV.
_ASSIGN_FIELDNAMES = [
    "tag",
    "fluffyrock_count",
    "sample_occurrences",
    "best_facet",
    "best_sim",
    "lexical_boost",
    "score",
    "second_facet",
    "second_sim",
    "margin",
    "decision",
]


def _write_assignments_csv(rows: List[Dict[str, str]]) -> None:
    """Overwrite OUT_ASSIGN with a header plus one line per row, creating its directory if needed."""
    OUT_ASSIGN.parent.mkdir(parents=True, exist_ok=True)
    with OUT_ASSIGN.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=_ASSIGN_FIELDNAMES)
        writer.writeheader()
        writer.writerows(rows)


def _project_auto_assignments(
    base_groups: Dict[str, Set[str]],
    rows: List[Dict[str, str]],
    include_explicit: bool,
) -> Dict[str, Set[str]]:
    """Return a copy of ``base_groups`` with auto-assigned tags added to their facet's mapped group.

    When ``include_explicit`` is False, tags auto-assigned to fluids_explicit_sensitive are left out.
    """
    projected = {k: set(v) for k, v in base_groups.items()}
    for r in rows:
        if r["decision"] != "auto_assign":
            continue
        facet = r["best_facet"]
        if not include_explicit and facet == "fluids_explicit_sensitive":
            continue
        projected.setdefault(_FACET_TO_GROUP[facet], set()).add(r["tag"])
    return projected


def main() -> None:
    """Score uncovered sample tags against facet centroids, write both outputs, and print a summary.

    Steps:
    1. Load tag counts and sample occurrences (tags below MIN_COUNT are dropped).
    2. Find tags that no baseline group covers.
    3. Compare each such tag's vector with every facet centroid and decide auto_assign or review.
    4. Write OUT_ASSIGN and OUT_SUMMARY, including coverage projected with and without
       the explicit facet.

    Raises:
        RuntimeError: if no facet has enough seed tags in the vector index to form a centroid.
    """
    counts = load_counts(COUNTS_CSV)
    occ = load_sample_tag_occurrences(SAMPLE_JSONL, counts, MIN_COUNT)
    all_tags = set(occ.keys())

    base_groups = load_base_groups()
    covered_base = {t for t in all_tags if any(t in g for g in base_groups.values())}
    # Most frequent first: global count, then sample occurrences as the tie-breaker.
    uncovered = sorted(all_tags - covered_base, key=lambda t: (counts.get(t, 0), occ[t]), reverse=True)

    vectors = get_tfidf_tag_vectors()
    vectors_norm = vectors["reduced_matrix_norm"]
    tag_to_row = vectors["tag_to_row_index"]
    centroids = build_centroids(tag_to_row, vectors_norm)
    if not centroids:
        # np.stack([]) below would otherwise fail with an opaque "need at least one array" error.
        raise RuntimeError(
            "No facet has enough seed tags in the vector index to build a centroid; "
            "check FACETS seeds against the TF-IDF tag vocabulary."
        )
    facet_names = sorted(centroids.keys())
    C = np.stack([centroids[f] for f in facet_names], axis=0)

    rows: List[Dict[str, str]] = []
    action_counts: Counter = Counter()
    facet_auto_counts: Counter = Counter()
    n_without_vectors = 0
    for tag in uncovered:
        if tag not in tag_to_row:
            # Cannot be scored; counted so the summary shows how many were skipped.
            n_without_vectors += 1
            continue
        sims = C @ vectors_norm[tag_to_row[tag]]
        order = np.argsort(sims)[::-1]
        i1 = int(order[0])
        # With a single facet there is no runner-up, so the margin is 0.
        i2 = int(order[1]) if sims.size > 1 else i1
        best_facet = facet_names[i1]
        best_sim = float(sims[i1])
        second_facet = facet_names[i2]
        second_sim = float(sims[i2])
        margin = best_sim - second_sim
        # The lexical boost is only checked against the best facet's patterns.
        lex = lexical_match_score(tag, best_facet)
        score = best_sim + lex
        decision = decision_for(tag, best_facet, best_sim, margin, lex)
        action_counts[decision] += 1
        if decision == "auto_assign":
            facet_auto_counts[best_facet] += 1
        rows.append(
            {
                "tag": tag,
                "fluffyrock_count": str(counts.get(tag, 0)),
                "sample_occurrences": str(occ[tag]),
                "best_facet": best_facet,
                "best_sim": f"{best_sim:.6f}",
                "lexical_boost": f"{lex:.2f}",
                "score": f"{score:.6f}",
                "second_facet": second_facet,
                "second_sim": f"{second_sim:.6f}",
                "margin": f"{margin:.6f}",
                "decision": decision,
            }
        )
    # Auto-assigned rows first, then by descending global count.
    rows.sort(key=lambda r: (r["decision"] != "auto_assign", -int(r["fluffyrock_count"])))
    _write_assignments_csv(rows)

    # Coverage projection, with and without the explicit-sensitive facet enabled.
    projected_all = _project_auto_assignments(base_groups, rows, include_explicit=True)
    projected_no_explicit = _project_auto_assignments(base_groups, rows, include_explicit=False)

    summary = {
        "min_count": MIN_COUNT,
        "n_unique_tags_considered": len(all_tags),
        "n_uncovered_before_guided_facets": len(uncovered),
        "n_uncovered_without_vectors": n_without_vectors,
        "facet_names": facet_names,
        "decision_counts": dict(action_counts),
        "auto_assign_counts_by_facet": dict(facet_auto_counts),
        "coverage": {
            "baseline_unique_pct": coverage_pct(base_groups, all_tags),
            "baseline_top15_pct": greedy_top15_pct(base_groups, occ),
            "projected_unique_pct_with_explicit_facet": coverage_pct(projected_all, all_tags),
            "projected_top15_pct_with_explicit_facet": greedy_top15_pct(projected_all, occ),
            "projected_unique_pct_without_explicit_facet": coverage_pct(projected_no_explicit, all_tags),
            "projected_top15_pct_without_explicit_facet": greedy_top15_pct(projected_no_explicit, occ),
        },
        "outputs": {
            "assignments_csv": str(OUT_ASSIGN),
            "summary_json": str(OUT_SUMMARY),
        },
    }
    OUT_SUMMARY.parent.mkdir(parents=True, exist_ok=True)
    with OUT_SUMMARY.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print("Uncovered before:", len(uncovered))
    print("Uncovered without vectors:", n_without_vectors)
    print("Decisions:", dict(action_counts))
    print("Auto by facet:", dict(facet_auto_counts))
    print("Coverage:", summary["coverage"])
    print("Outputs:", OUT_ASSIGN, OUT_SUMMARY)


if __name__ == "__main__":
    main()