Spaces:
Running
Running
File size: 5,008 Bytes
50a9851 5188881 50a9851 5188881 50a9851 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | #!/usr/bin/env python3
"""
Test script for categorized tag suggestions.
"""
from pathlib import Path
import sys
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from psq_rag.tagging.category_parser import parse_checklist
from psq_rag.tagging.categorized_suggestions import (
generate_categorized_suggestions,
load_categories,
)
def test_parse_checklist():
    """Parse ``tagging_checklist.txt`` and print a summary of every category.

    The checklist is looked up in the parent of this script's directory. For
    each parsed category this prints its display name, tier, constraint, tag
    count, the first five tags, and its ``depends_on`` / ``skip_if`` values
    when those are set.

    Returns:
        True if the checklist was found and summarised. False if the file is
        missing, or if parsing/printing raises (the traceback is printed).
        Returning False keeps the script's runner going so the remaining
        checks and the final summary still execute.
    """
    print("=" * 80)
    print("Testing checklist parsing...")
    print("=" * 80)
    checklist_path = Path(__file__).parent.parent / "tagging_checklist.txt"
    if not checklist_path.exists():
        print(f"ERROR: Checklist not found at {checklist_path}")
        return False
    # Same error policy as the other checks in this script: report the
    # failure and return False rather than letting it abort the whole run.
    try:
        categories = parse_checklist(checklist_path)
        print(f"\nParsed {len(categories)} categories:")
        for cat_name, category in categories.items():
            print(f"\n {cat_name}:")
            print(f" Display: {category.display_name}")
            print(f" Tier: {category.tier.name}")
            print(f" Constraint: {category.constraint.value}")
            print(f" Tags: {len(category.tags)} tags")
            print(f" Sample tags: {category.tags[:5]}")
            if category.depends_on:
                print(f" Depends on: {category.depends_on}")
            if category.skip_if:
                print(f" Skip if: {category.skip_if}")
    except Exception as e:
        print(f"ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
    return True
def test_categorized_suggestions():
    """Generate categorized suggestions for a sample tag set and print them.

    Calls ``generate_categorized_suggestions`` with NSFW tags disabled,
    5 suggestions per category and 10 "other" suggestions. It then prints
    up to three ``tag:score`` pairs for each category that has suggestions,
    plus up to five of the "other" suggestions.

    Returns:
        True when generation and printing complete; False if anything
        raises (the error and traceback are printed first).
    """
    banner = "=" * 80
    print("\n" + banner)
    print("Testing categorized suggestions...")
    print(banner)

    checklist_path = Path(__file__).parent.parent / "tagging_checklist.txt"

    # Example: a user prompt resulted in these LLM-selected tags.
    selected_tags = ["anthro", "canine", "male", "solo", "forest", "standing"]
    print(f"\nSelected tags: {', '.join(selected_tags)}")
    print("\nGenerating categorized suggestions...")

    def _preview(pairs):
        # Render (tag, score) pairs as "tag:0.123, tag:0.456".
        return ", ".join(f"{tag}:{score:.3f}" for tag, score in pairs)

    try:
        categorized = generate_categorized_suggestions(
            selected_tags,
            allow_nsfw_tags=False,
            top_n_per_category=5,
            top_n_other=10,
            checklist_path=checklist_path,
        )
        print("\nCategory summary:")
        for name, group in categorized.by_category.items():
            if not group.suggestions:
                continue
            shown = _preview(group.suggestions[:3])
            print(f" {name}: {shown}")
        if categorized.other_suggestions:
            extra = _preview(categorized.other_suggestions[:5])
            print(f" other: {extra}")
        return True
    except Exception as e:
        print(f"ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_zero_pictured():
    """Check that ``zero_pictured`` leaves character categories without suggestions.

    Generates suggestions for a scene-only tag set and prints every category
    that has suggestions or already-selected tags. It then checks that none of
    ``body_type``, ``species``, ``gender`` or ``clothing`` received
    suggestions; a category absent from the result counts as skipped.

    Returns:
        True if all checked categories are empty or absent. False if any has
        suggestions, or if generation raises (error and traceback printed).
    """
    banner = "=" * 80
    print("\n" + banner)
    print("Testing zero_pictured dependency logic...")
    print(banner)

    checklist_path = Path(__file__).parent.parent / "tagging_checklist.txt"
    selected_tags = ["zero_pictured", "forest", "outside"]
    print(f"\nSelected tags: {', '.join(selected_tags)}")
    print("(Should skip character/appearance categories)")

    try:
        categorized = generate_categorized_suggestions(
            selected_tags,
            allow_nsfw_tags=False,
            top_n_per_category=5,
            top_n_other=10,
            checklist_path=checklist_path,
        )
        by_category = categorized.by_category

        print("\nCategories with suggestions:")
        for name, group in by_category.items():
            if group.suggestions or group.already_selected:
                print(f" {name}: {len(group.suggestions)} suggestions")

        # Character categories must not receive any suggestions.
        all_skipped = True
        for cat in ("body_type", "species", "gender", "clothing"):
            if cat in by_category and by_category[cat].suggestions:
                print(f" WARNING: {cat} should have been skipped!")
                all_skipped = False

        if all_skipped:
            print("\n✓ All character categories correctly skipped!")
        return all_skipped
    except Exception as e:
        print(f"ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Run every check unconditionally (no short-circuiting) so each one
    # prints its own report, then combine the results into one status.
    results = [
        test_parse_checklist(),
        test_categorized_suggestions(),
        test_zero_pictured(),
    ]
    success = all(results)

    separator = "=" * 80
    print("\n" + separator)
    print("✓ All tests passed!" if success else "✗ Some tests failed")
    print(separator)
    # Exit status: 0 when every check passed, 1 otherwise.
    sys.exit(0 if success else 1)
|