Prompt_Squirrel_RAG / scripts /test_parser_only.py
Claude
Add tag categorization pipeline for e621 checklist
50a9851
#!/usr/bin/env python3
"""
Test script for category parser only (no TF-IDF dependencies).
"""
from pathlib import Path
import sys
# Add the parent of this script's directory to sys.path so `psq_rag` can be imported
sys.path.insert(0, str(Path(__file__).parent.parent))
from psq_rag.tagging.category_parser import parse_checklist, should_skip_category
def test_parse_checklist():
    """Parse the checklist file and print a per-tier report of its categories.

    The checklist is looked up as ``tagging_checklist.txt`` two directory
    levels above this script.

    Returns:
        False if the checklist file does not exist; True once the report
        has been printed.
    """
    banner = "=" * 80
    print(banner)
    print("Testing checklist parsing...")
    print(banner)

    checklist_path = Path(__file__).parent.parent / "tagging_checklist.txt"
    if not checklist_path.exists():
        print(f"ERROR: Checklist not found at {checklist_path}")
        return False

    categories = parse_checklist(checklist_path)
    print(f"\nParsed {len(categories)} categories:\n")

    # Bucket (name, category) pairs under their tier name, in iteration order.
    tiers = {}
    for name, cat in categories.items():
        tiers.setdefault(cat.tier.name, []).append((name, cat))

    rule = "=" * 60
    for tier_name in sorted(tiers):
        print(f"\n{rule}")
        print(f"{tier_name} Tier")
        print(rule)

        for name, cat in tiers[tier_name]:
            tag_count = len(cat.tags)
            # Only the first eight tags are listed; the rest are summarised.
            preview = ", ".join(cat.tags[:8])

            print(f"\n Category: {name}")
            print(f" Display: {cat.display_name}")
            print(f" Constraint: {cat.constraint.value}")
            print(f" Tags ({tag_count}): {preview}")
            if tag_count > 8:
                print(f" ... and {tag_count - 8} more")
            if cat.depends_on:
                print(f" Depends on: {cat.depends_on}")
            if cat.skip_if:
                print(f" Skip if: {cat.skip_if}")

    return True
def _check_skip_expectation(categories, cat_names, selected_tags, expect_skip):
    """Print and verify the skip decision for each named category.

    Args:
        categories: Mapping of category name to category, as returned by
            ``parse_checklist``.
        cat_names: Category names to check.
        selected_tags: Set of tags treated as already selected.
        expect_skip: True if every category should be skipped, False if
            none should be.

    Returns:
        True if every category that exists in ``categories`` matched
        ``expect_skip``. Names missing from ``categories`` are reported
        but do not count as failures.
    """
    all_ok = True
    for cat_name in cat_names:
        if cat_name not in categories:
            # Make the gap visible instead of silently passing.
            print(f" {cat_name}: not in checklist (not checked)")
            continue

        skipped = bool(
            should_skip_category(categories[cat_name], selected_tags, categories)
        )
        mark = "✓" if skipped == expect_skip else "✗"
        action = "SKIP" if skipped else "KEEP"
        print(f" {cat_name}: {mark} {action}")
        if skipped != expect_skip:
            all_ok = False
    return all_ok


def test_skip_logic():
    """Test the skip logic driven by the ``zero_pictured`` tag.

    Two scenarios are run over the same character categories:

    1. ``zero_pictured`` is selected: every category must be skipped.
    2. ``solo`` is selected: no category may be skipped.

    Returns:
        True if every checked category behaved as expected; False on any
        mismatch or when the checklist file does not exist.
    """
    print("\n" + "=" * 80)
    print("Testing skip logic...")
    print("=" * 80)

    checklist_path = Path(__file__).parent.parent / "tagging_checklist.txt"
    # Same guard as test_parse_checklist, so a missing file is reported
    # rather than handed to the parser.
    if not checklist_path.exists():
        print(f"ERROR: Checklist not found at {checklist_path}")
        return False
    categories = parse_checklist(checklist_path)

    character_cats = ['body_type', 'species', 'gender', 'clothing']

    # Test 1: zero_pictured should skip character categories
    selected_tags = {'zero_pictured', 'forest', 'outside'}
    print(f"\nTest 1: Selected tags = {selected_tags}")
    print("Should skip character/appearance categories:")
    skip_ok = _check_skip_expectation(
        categories, character_cats, selected_tags, expect_skip=True
    )

    # Test 2: solo should NOT skip character categories
    selected_tags = {'solo', 'anthro', 'male'}
    print(f"\nTest 2: Selected tags = {selected_tags}")
    print("Should NOT skip character categories:")
    keep_ok = _check_skip_expectation(
        categories, character_cats, selected_tags, expect_skip=False
    )

    return skip_ok and keep_ok
def test_tag_extraction():
    """Test that expected tags are extracted for selected categories.

    For each category in the table below, the expected tags must be a
    subset of the tags parsed for that category. Additional parsed tags
    are allowed.

    Returns:
        True if every listed category exists and contains all of its
        expected tags; False otherwise, or when the checklist file does
        not exist.
    """
    print("\n" + "=" * 80)
    print("Testing tag extraction...")
    print("=" * 80)

    checklist_path = Path(__file__).parent.parent / "tagging_checklist.txt"
    # Same guard as test_parse_checklist, so a missing file is reported
    # rather than handed to the parser.
    if not checklist_path.exists():
        print(f"ERROR: Checklist not found at {checklist_path}")
        return False
    categories = parse_checklist(checklist_path)

    # Check specific categories we care about
    test_cases = {
        'count': ['solo', 'duo', 'trio', 'group', 'zero_pictured'],
        'rating': ['safe', 'questionable', 'explicit'],  # These might not parse correctly
        'body_type': ['anthro', 'feral', 'humanoid', 'taur'],
        'species': ['human', 'canine', 'feline', 'equine'],
        'gender': ['male', 'female', 'intersex', 'ambiguous_gender'],
        'clothing': ['fully_clothed', 'partially_clothed', 'nude'],
        'location': ['inside', 'outside', 'bedroom', 'kitchen', 'forest'],
    }

    all_correct = True
    for cat_name, expected_tags in test_cases.items():
        if cat_name not in categories:
            print(f"\n✗ Category '{cat_name}' not found!")
            all_correct = False
            continue

        category = categories[cat_name]
        # Only missing tags are a failure; extra parsed tags are fine.
        missing = set(expected_tags) - set(category.tags)
        if missing:
            print(f"\n{cat_name}:")
            print(f" ✗ Missing expected tags: {missing}")
            all_correct = False
        else:
            print(f"\n✓ {cat_name}: All expected tags found")

        print(f" Tags: {', '.join(sorted(category.tags))}")

    return all_correct
if __name__ == "__main__":
    # Build the list eagerly so every test runs and prints its report even
    # when an earlier one fails; all() then only combines the results.
    results = [
        test_parse_checklist(),
        test_tag_extraction(),
        test_skip_logic(),
    ]
    success = all(results)

    print("\n" + "=" * 80)
    if success:
        print("✓ All tests passed!")
    else:
        print("✗ Some tests failed")
    print("=" * 80)

    # Exit status 0 on success, 1 on any failure.
    sys.exit(0 if success else 1)