# Prompt_Squirrel_RAG / scripts/propose_category_expansion.py
# (paste residue preserved as comments — branch: "Food Desert";
#  commit: "Consolidate probe configs and eval artifacts on main", 6e50f4d)
"""Create a concrete category-expansion proposal and estimate coverage impact.
Inputs:
- data/analysis/tag_group_uncovered_after_topn_combined200.csv
- data/category_registry.csv
- data/tag_groups.json
- fluffyrock_3m.csv
- data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000.jsonl
Outputs:
- data/analysis/category_expansion_proposal.csv
- data/analysis/category_expansion_coverage.json
"""
from __future__ import annotations
import csv
import json
from collections import Counter
from pathlib import Path
from typing import Dict, List, Set, Tuple
# Repository root: this script lives in <repo>/scripts/, so go up one level.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Inputs.
UNCOVERED_PATH = REPO_ROOT / "data" / "analysis" / "tag_group_uncovered_after_topn_combined200.csv"
REGISTRY_PATH = REPO_ROOT / "data" / "category_registry.csv"
TAG_GROUPS_PATH = REPO_ROOT / "data" / "tag_groups.json"
FLUFFYROCK_PATH = REPO_ROOT / "fluffyrock_3m.csv"
SAMPLE_PATH = REPO_ROOT / "data" / "eval_samples" / "e621_sfw_sample_1000_seed123_buffer10000.jsonl"
# Outputs.
OUT_PROPOSAL = REPO_ROOT / "data" / "analysis" / "category_expansion_proposal.csv"
OUT_COVERAGE = REPO_ROOT / "data" / "analysis" / "category_expansion_coverage.json"
# Tags with a fluffyrock count below this are dropped from coverage measurement.
MIN_COUNT = 200
# How many greedy-selected groups to report in the top-N summary.
TOP_N_GROUPS = 15
# Hard cap on greedy set-cover iterations.
MAX_STEPS = 25
def _load_counts(path: Path) -> Dict[str, int]:
out: Dict[str, int] = {}
with path.open("r", encoding="utf-8", newline="") as f:
reader = csv.reader(f)
for row in reader:
if len(row) < 3:
continue
try:
out[row[0]] = int(row[2]) if row[2] else 0
except ValueError:
out[row[0]] = 0
return out
def _load_sample_tags(path: Path, counts: Dict[str, int], min_count: int) -> List[Set[str]]:
rows: List[Set[str]] = []
with path.open("r", encoding="utf-8") as f:
for line in f:
obj = json.loads(line)
raw = obj.get("tags_ground_truth_categorized", "")
if not raw:
continue
try:
d = json.loads(raw)
except Exception:
continue
tags: Set[str] = set()
if isinstance(d, dict):
for vals in d.values():
if isinstance(vals, list):
for t in vals:
if isinstance(t, str) and counts.get(t, 0) >= min_count:
tags.add(t)
if tags:
rows.append(tags)
return rows
def _load_wiki_groups(path: Path) -> Dict[str, Set[str]]:
with path.open("r", encoding="utf-8") as f:
raw = json.load(f)
return {k: set(v) for k, v in raw.items() if isinstance(v, list)}
def _load_category_groups(path: Path) -> Dict[str, Set[str]]:
groups: Dict[str, Set[str]] = {}
with path.open("r", encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
if (row.get("category_enabled") or "").strip() not in {"1", "true", "True"}:
continue
c = (row.get("category_name") or "").strip()
t = (row.get("tag") or "").strip()
if c and t:
groups.setdefault(f"cat:{c}", set()).add(t)
return groups
def _greedy(groups: Dict[str, Set[str]], tag_occ: Counter, max_steps: int) -> Tuple[List[Dict[str, object]], Set[str]]:
uncovered = Counter(tag_occ)
chosen: Set[str] = set()
selected: List[Dict[str, object]] = []
total = sum(tag_occ.values())
covered = 0
for step in range(1, max_steps + 1):
best, best_gain = None, 0
best_new: Set[str] = set()
for g, tags in groups.items():
if g in chosen:
continue
gain = 0
new_tags: Set[str] = set()
for t in tags:
c = uncovered.get(t, 0)
if c > 0:
gain += c
new_tags.add(t)
if gain > best_gain:
best, best_gain, best_new = g, gain, new_tags
if not best or best_gain <= 0:
break
chosen.add(best)
for t in best_new:
uncovered[t] = 0
covered += best_gain
selected.append(
{
"step": step,
"group": best,
"gain_occurrences": best_gain,
"cumulative_covered_occurrences": covered,
"cumulative_covered_pct": round(covered / total * 100.0, 2) if total else 0.0,
}
)
return selected, chosen
def _recommend(tag: str) -> Tuple[str, str, str]:
if tag in {"solo", "duo", "trio", "group", "solo_focus"}:
return "new_category", "character_count", "mutually exclusive count-like options"
if "/" in tag or tag in {"romantic_couple", "interspecies"}:
return "new_category", "relationship_pairing", "relationship/pairing semantics shown best together"
if tag in {"muscular", "muscular_anthro", "slightly_chubby", "overweight", "thick_thighs", "wide_hips", "big_butt"}:
return "new_category", "body_build", "body-shape alternatives useful side-by-side"
if tag in {
"canid", "canis", "felid", "felis", "equid", "domestic_dog", "domestic_cat",
"wolf", "fox", "dragon", "reptile", "leporid", "rabbit", "horse", "pony",
"pantherine", "bovid", "animal_humanoid", "hybrid",
}:
return "new_category", "species_specific", "taxonomy/detail species cluster"
if any(tag.startswith(c) for c in ("red_", "blue_", "green_", "yellow_", "black_", "white_", "brown_", "grey_", "purple_", "orange_", "teal_")):
return "merge_existing", "color_markings", "color-region/attribute tag"
if "hair" in tag:
return "merge_existing", "hair", "hair style/color detail"
if tag in {"nipples", "areola", "butt", "navel", "feet", "belly", "abs", "pecs", "teeth", "tongue", "tail", "horn", "wings", "claws", "fangs", "fingers", "toes"}:
return "merge_existing", "anatomy_features", "anatomy/body-part trait"
if tag in {"half-closed_eyes", "eyelashes", "eyebrows"}:
return "merge_existing", "expression_detail", "eye/expression detail"
if tag in {"bodily_fluids", "saliva", "sweat", "nude", "bound", "bottomless", "hyper"}:
return "deprioritize", "none", "sensitive/noisy for default non-explicit-centric UX"
if tag in {"pose", "holding_object", "rear_view", "licking", "biped"}:
return "merge_existing", "pose_action_detail", "pose/action detail"
if tag in {"eyewear", "jewelry", "glasses", "hat", "gloves", "panties"}:
return "merge_existing", "clothing_detail", "attire/accessory detail"
if tag in {"fur", "tuft", "feathers", "not_furry", "anthrofied"}:
return "merge_existing", "fur_style", "fur/covering style detail"
return "needs_review", "uncategorized_review", "high-frequency uncovered tag needing manual judgment"
# Column order for the proposal CSV.
_PROPOSAL_FIELDS = [
    "tag",
    "fluffyrock_count",
    "sample_occurrences",
    "proposed_action",
    "target_category",
    "in_art_tag_group",
    "reason",
]


def _build_proposal_rows(art_group: Set[str]) -> List[Dict[str, str]]:
    """Build one proposal row per tag in UNCOVERED_PATH (already ranked by frequency)."""
    rows: List[Dict[str, str]] = []
    with UNCOVERED_PATH.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            tag = row["tag"]
            action, target, why = _recommend(tag)
            rows.append(
                {
                    "tag": tag,
                    "fluffyrock_count": row.get("fluffyrock_count", ""),
                    "sample_occurrences": row.get("sample_occurrences", ""),
                    "proposed_action": action,
                    "target_category": target,
                    "in_art_tag_group": "1" if tag in art_group else "0",
                    "reason": why,
                }
            )
    return rows


def _write_proposal_csv(rows: List[Dict[str, str]]) -> None:
    """Write the proposal rows to OUT_PROPOSAL, creating parent directories as needed."""
    OUT_PROPOSAL.parent.mkdir(parents=True, exist_ok=True)
    with OUT_PROPOSAL.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=_PROPOSAL_FIELDS)
        writer.writeheader()
        writer.writerows(rows)


def _project_groups(base_groups: Dict[str, Set[str]], proposal_rows: List[Dict[str, str]]) -> Dict[str, Set[str]]:
    """Return a deep copy of base_groups with accepted proposal tags merged in.

    Only "new_category" / "merge_existing" actions with a real target are
    applied; targets become (or extend) "cat:<target>" groups.
    """
    projected: Dict[str, Set[str]] = {k: set(v) for k, v in base_groups.items()}
    for row in proposal_rows:
        if row["proposed_action"] not in {"new_category", "merge_existing"}:
            continue
        target = row["target_category"].strip()
        if not target or target == "none":
            continue
        projected.setdefault(f"cat:{target}", set()).add(row["tag"])
    return projected


def _topn_pct(greedy_rows: List[Dict[str, object]], topn: int) -> object:
    """Cumulative covered % after *topn* greedy steps (last step if fewer; 0.0 if none)."""
    if len(greedy_rows) >= topn:
        return greedy_rows[topn - 1]["cumulative_covered_pct"]
    return greedy_rows[-1]["cumulative_covered_pct"] if greedy_rows else 0.0


def main() -> None:
    """Generate the category-expansion proposal CSV and coverage-impact JSON.

    Pipeline: load tag counts and the eval sample, measure baseline coverage of
    the current wiki+registry groups, emit a per-tag expansion proposal, then
    re-measure coverage with the proposal applied and write a JSON summary.
    """
    counts = _load_counts(FLUFFYROCK_PATH)
    sample_rows = _load_sample_tags(SAMPLE_PATH, counts, MIN_COUNT)
    wiki_groups = _load_wiki_groups(TAG_GROUPS_PATH)
    category_groups = _load_category_groups(REGISTRY_PATH)
    # Registry "cat:*" keys cannot collide with wiki group names.
    base_groups = {**wiki_groups, **category_groups}

    tag_occ: Counter = Counter()
    for tags in sample_rows:
        tag_occ.update(tags)

    # Baseline coverage with current wiki+category groups.
    covered_any_base = {t for t in tag_occ if any(t in g for g in base_groups.values())}
    greedy_base, _ = _greedy(base_groups, tag_occ, MAX_STEPS)

    # Build proposal from uncovered-after-topN file (already ranked by frequency).
    proposal_rows = _build_proposal_rows(wiki_groups.get("art", set()))
    _write_proposal_csv(proposal_rows)

    # Apply recommendations to projection groups and re-measure coverage.
    projected_groups = _project_groups(base_groups, proposal_rows)
    covered_any_projected = {t for t in tag_occ if any(t in g for g in projected_groups.values())}
    greedy_projected, _ = _greedy(projected_groups, tag_occ, MAX_STEPS)

    base_topn_pct = _topn_pct(greedy_base, TOP_N_GROUPS)
    proj_topn_pct = _topn_pct(greedy_projected, TOP_N_GROUPS)

    summary = {
        "inputs": {
            "min_count": MIN_COUNT,
            "top_n_groups": TOP_N_GROUPS,
            "sample_file": str(SAMPLE_PATH),
            "proposal_source_uncovered": str(UNCOVERED_PATH),
        },
        "proposal_counts": dict(Counter(r["proposed_action"] for r in proposal_rows)),
        "art_tags_in_proposal": [r for r in proposal_rows if r["in_art_tag_group"] == "1"],
        "coverage_baseline": {
            "n_groups": len(base_groups),
            "unique_covered_pct": round((len(covered_any_base) / len(tag_occ) * 100.0), 2) if tag_occ else 0.0,
            "top15_greedy_cumulative_pct": base_topn_pct,
            "top15_groups": [x["group"] for x in greedy_base[:TOP_N_GROUPS]],
        },
        "coverage_projected_with_proposal": {
            "n_groups": len(projected_groups),
            "unique_covered_pct": round((len(covered_any_projected) / len(tag_occ) * 100.0), 2) if tag_occ else 0.0,
            "top15_greedy_cumulative_pct": proj_topn_pct,
            "top15_groups": [x["group"] for x in greedy_projected[:TOP_N_GROUPS]],
        },
        "outputs": {
            "proposal_csv": str(OUT_PROPOSAL),
            "coverage_json": str(OUT_COVERAGE),
        },
    }
    with OUT_COVERAGE.open("w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print("Proposal rows:", len(proposal_rows))
    print("Proposal action counts:", summary["proposal_counts"])
    print("Baseline unique covered %:", summary["coverage_baseline"]["unique_covered_pct"])
    print("Projected unique covered %:", summary["coverage_projected_with_proposal"]["unique_covered_pct"])
    print("Baseline top15 greedy %:", summary["coverage_baseline"]["top15_greedy_cumulative_pct"])
    print("Projected top15 greedy %:", summary["coverage_projected_with_proposal"]["top15_greedy_cumulative_pct"])
    print("Outputs:", OUT_PROPOSAL, OUT_COVERAGE)


if __name__ == "__main__":
    main()