Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Test script for category parser only (no TF-IDF dependencies). | |
| """ | |
| from pathlib import Path | |
| import sys | |
| # Add parent directory to path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from psq_rag.tagging.category_parser import parse_checklist, should_skip_category | |
def test_parse_checklist():
    """Parse ``tagging_checklist.txt`` and print every category grouped by tier.

    The checklist is looked up one directory above this script's folder.
    For each category this prints the display name, the constraint value,
    the first eight tags (plus a count of any remainder), and the
    ``depends_on`` / ``skip_if`` values when they are set.

    Returns:
        False if the checklist file does not exist, otherwise True. The
        return value only reflects whether the file was found; the parsed
        content itself is printed for inspection, not asserted on.
    """
    banner = "=" * 80
    print(banner)
    print("Testing checklist parsing...")
    print(banner)

    checklist_path = Path(__file__).parent.parent / "tagging_checklist.txt"
    if not checklist_path.exists():
        print(f"ERROR: Checklist not found at {checklist_path}")
        return False

    categories = parse_checklist(checklist_path)
    print(f"\nParsed {len(categories)} categories:\n")

    # Bucket (name, category) pairs by tier name; setdefault keeps the
    # parser's ordering within each tier.
    grouped = {}
    for name, cat in categories.items():
        grouped.setdefault(cat.tier.name, []).append((name, cat))

    preview_limit = 8
    rule = "=" * 60
    for tier in sorted(grouped):
        print(f"\n{rule}")
        print(f"{tier} Tier")
        print(rule)
        for name, cat in grouped[tier]:
            print(f"\n Category: {name}")
            print(f" Display: {cat.display_name}")
            print(f" Constraint: {cat.constraint.value}")
            print(f" Tags ({len(cat.tags)}): {', '.join(cat.tags[:preview_limit])}")
            remaining = len(cat.tags) - preview_limit
            if remaining > 0:
                print(f" ... and {remaining} more")
            if cat.depends_on:
                print(f" Depends on: {cat.depends_on}")
            if cat.skip_if:
                print(f" Skip if: {cat.skip_if}")
    return True
def test_skip_logic():
    """Test the skip logic for zero_pictured.

    Two scenarios are checked against the parsed checklist:

    1. ``zero_pictured`` is among the selected tags, so every category in
       ``character_cats`` should be skipped.
    2. ``solo`` is among the selected tags, so none of those categories
       should be skipped.

    A category listed in ``character_cats`` but absent from the checklist is
    reported and counted as a failure. Silently ignoring it would let the
    test pass without checking anything.

    Returns:
        True if every check behaved as expected. False if any check failed,
        a category was missing, or the checklist file does not exist.

    NOTE(review): this reports failure through its return value (consumed
    by the ``__main__`` block). If collected by pytest, a ``False`` return
    is not treated as a test failure.
    """
    print("\n" + "=" * 80)
    print("Testing skip logic...")
    print("=" * 80)
    checklist_path = Path(__file__).parent.parent / "tagging_checklist.txt"
    # Same guard as test_parse_checklist, so a missing file is reported
    # instead of surfacing as an error from parse_checklist.
    if not checklist_path.exists():
        print(f"ERROR: Checklist not found at {checklist_path}")
        return False
    categories = parse_checklist(checklist_path)

    character_cats = ['body_type', 'species', 'gender', 'clothing']

    def _check(selected_tags, expect_skip):
        """Return True if every character category's skip decision equals expect_skip."""
        ok = True
        for cat_name in character_cats:
            if cat_name not in categories:
                print(f" {cat_name}: ✗ category not found")
                ok = False
                continue
            should_skip = should_skip_category(
                categories[cat_name],
                selected_tags,
                categories,
            )
            mark = "✓" if should_skip == expect_skip else "✗"
            action = "SKIP" if should_skip else "KEEP"
            print(f" {cat_name}: {mark} {action}")
            if should_skip != expect_skip:
                ok = False
        return ok

    # Test 1: zero_pictured should skip character categories
    selected_tags = {'zero_pictured', 'forest', 'outside'}
    print(f"\nTest 1: Selected tags = {selected_tags}")
    print("Should skip character/appearance categories:")
    test1_ok = _check(selected_tags, expect_skip=True)

    # Test 2: solo should NOT skip character categories
    selected_tags = {'solo', 'anthro', 'male'}
    print(f"\nTest 2: Selected tags = {selected_tags}")
    print("Should NOT skip character categories:")
    test2_ok = _check(selected_tags, expect_skip=False)

    return test1_ok and test2_ok
def test_tag_extraction():
    """Test that tags are extracted correctly from descriptions.

    For each category in ``test_cases``, checks that every expected tag is
    present in the parsed category's ``tags``. Extra tags are allowed. The
    full parsed tag list is printed for each category that exists, whether
    or not tags are missing, to help diagnose parse problems.

    Returns:
        True if every expected category exists and contains all of its
        expected tags. False otherwise, including when the checklist file
        does not exist.

    NOTE(review): this reports failure through its return value (consumed
    by the ``__main__`` block). If collected by pytest, a ``False`` return
    is not treated as a test failure.
    """
    print("\n" + "=" * 80)
    print("Testing tag extraction...")
    print("=" * 80)
    checklist_path = Path(__file__).parent.parent / "tagging_checklist.txt"
    # Same guard as test_parse_checklist, so a missing file is reported
    # instead of surfacing as an error from parse_checklist.
    if not checklist_path.exists():
        print(f"ERROR: Checklist not found at {checklist_path}")
        return False
    categories = parse_checklist(checklist_path)
    # Check specific categories we care about (expected tags are a subset).
    test_cases = {
        'count': ['solo', 'duo', 'trio', 'group', 'zero_pictured'],
        'rating': ['safe', 'questionable', 'explicit'],  # These might not parse correctly
        'body_type': ['anthro', 'feral', 'humanoid', 'taur'],
        'species': ['human', 'canine', 'feline', 'equine'],
        'gender': ['male', 'female', 'intersex', 'ambiguous_gender'],
        'clothing': ['fully_clothed', 'partially_clothed', 'nude'],
        'location': ['inside', 'outside', 'bedroom', 'kitchen', 'forest'],
    }
    all_correct = True
    for cat_name, expected_tags in test_cases.items():
        if cat_name not in categories:
            print(f"\n✗ Category '{cat_name}' not found!")
            all_correct = False
            continue
        category = categories[cat_name]
        missing = set(expected_tags) - set(category.tags)
        if missing:
            print(f"\n{cat_name}:")
            # Sorted so the report is stable across runs (set order is not).
            print(f" ✗ Missing expected tags: {sorted(missing)}")
            all_correct = False
        else:
            print(f"\n✓ {cat_name}: All expected tags found")
        print(f" Tags: {', '.join(sorted(category.tags))}")
    return all_correct
if __name__ == "__main__":
    # Every test runs even if an earlier one fails (no short-circuit with
    # &=), so the full report is always printed; the exit code is 0 only if
    # all of them returned True.
    success = True
    success &= test_parse_checklist()
    success &= test_tag_extraction()
    success &= test_skip_logic()
    print("\n" + "=" * 80)
    if success:
        print("✓ All tests passed!")
    else:
        print("✗ Some tests failed")
    print("=" * 80)
    sys.exit(0 if success else 1)