Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Test script for categorized tag suggestions. | |
| """ | |
| from pathlib import Path | |
| import sys | |
| # Add parent directory to path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from psq_rag.tagging.category_parser import parse_checklist | |
| from psq_rag.tagging.categorized_suggestions import ( | |
| generate_categorized_suggestions, | |
| load_categories, | |
| ) | |
def test_parse_checklist():
    """Parse ``tagging_checklist.txt`` and print a summary of every category.

    The checklist is looked up two directories above this script. For each
    parsed category the display name, tier, constraint, tag count, a sample
    of up to five tags, and any ``depends_on`` / ``skip_if`` rules are printed.

    Returns:
        True if the checklist exists, parses and prints without raising, and
        yields at least one category; False otherwise. Errors are reported on
        stdout (with a traceback) rather than propagated, matching the other
        tests in this script so a failure here does not stop them from running.
    """
    print("=" * 80)
    print("Testing checklist parsing...")
    print("=" * 80)
    checklist_path = Path(__file__).parent.parent / "tagging_checklist.txt"
    if not checklist_path.exists():
        print(f"ERROR: Checklist not found at {checklist_path}")
        return False
    try:
        categories = parse_checklist(checklist_path)
        print(f"\nParsed {len(categories)} categories:")
        for cat_name, category in categories.items():
            print(f"\n {cat_name}:")
            print(f" Display: {category.display_name}")
            print(f" Tier: {category.tier.name}")
            print(f" Constraint: {category.constraint.value}")
            print(f" Tags: {len(category.tags)} tags")
            print(f" Sample tags: {category.tags[:5]}")
            if category.depends_on:
                print(f" Depends on: {category.depends_on}")
            if category.skip_if:
                print(f" Skip if: {category.skip_if}")
    except Exception as e:
        # Report and fail instead of aborting the script, so the remaining
        # tests in __main__ still execute and the summary is still printed.
        print(f"ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
    if not categories:
        # An empty parse means the checklist format was not recognized;
        # treat it as a failure rather than passing vacuously.
        print(f"ERROR: No categories parsed from {checklist_path}")
        return False
    return True
def test_categorized_suggestions():
    """Generate categorized suggestions for a sample tag set and print them.

    A fixed list of tags stands in for tags an LLM might pick from a user
    prompt. The top three suggestions of every non-empty category and the top
    five uncategorized ("other") suggestions are printed with their scores.

    Returns:
        True if generation and printing finish without raising; False if any
        exception occurs (its message and traceback are printed).
    """
    print("\n" + "=" * 80)
    print("Testing categorized suggestions...")
    print("=" * 80)

    checklist_file = Path(__file__).parent.parent / "tagging_checklist.txt"
    # Example: User prompt resulted in these LLM-selected tags
    sample_tags = ["anthro", "canine", "male", "solo", "forest", "standing"]

    print(f"\nSelected tags: {', '.join(sample_tags)}")
    print("\nGenerating categorized suggestions...")

    def render(pairs):
        # Format (tag, score) pairs as "tag:0.123, tag:0.456, ...".
        return ", ".join(f"{tag}:{score:.3f}" for tag, score in pairs)

    try:
        result = generate_categorized_suggestions(
            sample_tags,
            allow_nsfw_tags=False,
            top_n_per_category=5,
            top_n_other=10,
            checklist_path=checklist_file,
        )
        print("\nCategory summary:")
        for name, bucket in result.by_category.items():
            if not bucket.suggestions:
                continue
            print(f" {name}: {render(bucket.suggestions[:3])}")
        if result.other_suggestions:
            print(f" other: {render(result.other_suggestions[:5])}")
        return True
    except Exception as e:
        print(f"ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_zero_pictured():
    """Check that character categories get no suggestions for ``zero_pictured``.

    Suggestions are generated for a scene-only tag set containing
    ``zero_pictured``. Categories that received suggestions or already have
    selected tags are listed, then ``body_type``, ``species``, ``gender`` and
    ``clothing`` are checked: any of them present in the result with a
    non-empty suggestion list is reported as a warning.

    Returns:
        True if none of those character categories has suggestions; False if
        any does or if an exception is raised (message and traceback printed).
    """
    print("\n" + "=" * 80)
    print("Testing zero_pictured dependency logic...")
    print("=" * 80)

    checklist_file = Path(__file__).parent.parent / "tagging_checklist.txt"
    sample_tags = ["zero_pictured", "forest", "outside"]

    print(f"\nSelected tags: {', '.join(sample_tags)}")
    print("(Should skip character/appearance categories)")

    try:
        result = generate_categorized_suggestions(
            sample_tags,
            allow_nsfw_tags=False,
            top_n_per_category=5,
            top_n_other=10,
            checklist_path=checklist_file,
        )

        print("\nCategories with suggestions:")
        for name, bucket in result.by_category.items():
            if bucket.suggestions or bucket.already_selected:
                print(f" {name}: {len(bucket.suggestions)} suggestions")

        # Check that character categories are empty; a category absent from
        # the result counts as skipped.
        skipped_ok = True
        for cat in ('body_type', 'species', 'gender', 'clothing'):
            if cat not in result.by_category:
                continue
            if result.by_category[cat].suggestions:
                print(f" WARNING: {cat} should have been skipped!")
                skipped_ok = False

        if skipped_ok:
            print("\n✓ All character categories correctly skipped!")
        return skipped_ok
    except Exception as e:
        print(f"ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Run every test unconditionally (no short-circuit) so each one reports
    # its own output, then print an overall verdict and set the exit code.
    outcomes = [
        test_parse_checklist(),
        test_categorized_suggestions(),
        test_zero_pictured(),
    ]
    success = all(outcomes)

    banner = "=" * 80
    print("\n" + banner)
    print("✓ All tests passed!" if success else "✗ Some tests failed")
    print(banner)
    sys.exit(0 if success else 1)