Spaces:
Running on Zero
Running on Zero
| """ | |
| Unit tests for model detection functionality. | |
| Tests the detect_model_type function in src/Core/Models/ModelFactory.py | |
| with various filename patterns and edge cases. | |
| Note: GGUF/FLUX support has been removed. GGUF files now raise ValueError. | |
| """ | |
| import os | |
| import sys | |
| import pytest | |
| from pathlib import Path | |
| from unittest.mock import patch, MagicMock | |
| # Add project root to path | |
| project_root = Path(__file__).resolve().parent.parent.parent | |
| sys.path.insert(0, str(project_root)) | |
| from src.Core.Models.ModelFactory import detect_model_type, list_available_models | |
class TestDetectModelType:
    """Test suite for detect_model_type.

    Covers SD1.5 detection (the fallback result), SDXL filename markers,
    rejection of GGUF files with ValueError, and degenerate inputs such as
    None, the empty string, and unknown or missing extensions.
    """

    # =========================================================================
    # SD1.5 Detection Tests
    # =========================================================================

    def test_detect_sd15_from_generic_safetensors(self):
        """SD1.5 should be detected for generic .safetensors files."""
        result = detect_model_type("model.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_pt_file(self):
        """SD1.5 should be detected for .pt files without SDXL marker."""
        result = detect_model_type("dreamshaper_8.pt")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_pth_file(self):
        """SD1.5 should be detected for .pth files without SDXL marker."""
        result = detect_model_type("anime_model.pth")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_dreamshaper(self):
        """DreamShaper models should be detected as SD1.5."""
        result = detect_model_type("DreamShaper_8_pruned.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_meina(self):
        """Meina models should be detected as SD1.5."""
        result = detect_model_type("Meina V10 - baked VAE.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_realistic_vision(self):
        """Realistic Vision models should be detected as SD1.5."""
        result = detect_model_type("realisticVisionV60.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_with_absolute_path(self):
        """Detection should work with absolute paths."""
        # Windows-style path. On POSIX, backslashes are not separators, but the
        # string still carries no SDXL marker, so SD15 is expected either way.
        result = detect_model_type("C:\\Models\\checkpoints\\my_model.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"
        # Unix-style path
        result = detect_model_type("/home/user/models/my_model.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_with_relative_path(self):
        """Detection should work with relative paths."""
        result = detect_model_type("./include/checkpoints/model.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    # =========================================================================
    # SDXL Detection Tests
    # =========================================================================

    def test_detect_sdxl_from_filename_marker(self):
        """SDXL should be detected from 'sdxl' in filename."""
        result = detect_model_type("juggernaut_sdxl_v9.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"

    @pytest.mark.parametrize(
        "filename",
        [
            "SDXL_model.safetensors",
            "Sdxl_model.safetensors",
            "model_SDXL.safetensors",
            "mySdXlModel.safetensors",
        ],
    )
    def test_detect_sdxl_case_insensitive(self, filename):
        """SDXL detection should be case-insensitive.

        Parametrized so every casing variant is reported as a separate case.
        A single loop would stop at the first failure.
        """
        result = detect_model_type(filename)
        assert result == "SDXL", f"Expected SDXL for {filename}, got {result}"

    def test_detect_sdxl_from_refiner(self):
        """SDXL should be detected from 'refiner' in filename."""
        result = detect_model_type("sd_xl_refiner_1.0.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"

    def test_detect_sdxl_from_hassaku(self):
        """SDXL should be detected from 'hassaku' in filename."""
        result = detect_model_type("hassakuXL_v13.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"

    def test_detect_sdxl_juggernaut(self):
        """Juggernaut XL models should be detected as SDXL due to 'juggernaut' indicator."""
        result = detect_model_type("Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors")
        assert result == "SDXL", f"Expected SDXL (juggernaut indicator), got {result}"

    def test_detect_sdxl_with_path(self):
        """SDXL detection works with full paths if basename contains marker."""
        # Detection looks at the basename only, not the directory components.
        result = detect_model_type("/models/checkpoints/sdxl_base_model.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"
        # 'sdxl' in a directory name but not in the filename falls back to SD15.
        result_nomarker = detect_model_type("/models/sdxl/base_model.safetensors")
        assert result_nomarker == "SD15", (
            f"Expected SD15 (marker not in basename), got {result_nomarker}"
        )

    # =========================================================================
    # GGUF Files - No Longer Supported (Must Raise ValueError)
    # =========================================================================

    def test_gguf_files_raise_value_error(self):
        """GGUF files should raise ValueError as they're no longer supported."""
        with pytest.raises(ValueError, match="GGUF files not supported"):
            detect_model_type("flux1-dev-Q8_0.gguf")

    @pytest.mark.parametrize(
        "filename",
        [
            "my_flux_model.gguf",
            "FLUX_model.gguf",
            "Flux_model.gguf",
            "model_FLUX.gguf",
            "random_model.gguf",
            "/models/flux/flux1-dev.gguf",
        ],
    )
    def test_gguf_any_filename_raises_error(self, filename):
        """Any .gguf file should raise ValueError, whatever its name or path."""
        with pytest.raises(ValueError, match="GGUF files not supported"):
            detect_model_type(filename)

    # =========================================================================
    # Edge Cases and Error Handling
    # =========================================================================

    def test_detect_none_input(self):
        """None input should return SD15 (default)."""
        result = detect_model_type(None)
        assert result == "SD15", f"Expected SD15 for None input, got {result}"

    def test_detect_empty_string(self):
        """Empty string should return SD15 (default)."""
        result = detect_model_type("")
        assert result == "SD15", f"Expected SD15 for empty string, got {result}"

    def test_detect_unknown_extension(self):
        """Unknown extensions should default to SD15."""
        result = detect_model_type("model.bin")
        assert result == "SD15", f"Expected SD15 for .bin file, got {result}"

    def test_detect_no_extension(self):
        """Files without extension should default to SD15."""
        result = detect_model_type("model_file")
        assert result == "SD15", f"Expected SD15 for no extension, got {result}"

    def test_detect_preserves_original_path(self):
        """Detection should be deterministic and leave its input untouched.

        Python strings are immutable, so asserting that the caller's variable
        is unchanged can never fail on its own. The meaningful checks are that
        repeated calls with the same path agree and return the expected type.
        """
        original_path = "path/to/model.safetensors"
        first = detect_model_type(original_path)
        second = detect_model_type(original_path)
        assert first == second, f"Non-deterministic result: {first} vs {second}"
        assert first == "SD15", f"Expected SD15, got {first}"
        assert original_path == "path/to/model.safetensors"
class TestListAvailableModels:
    """Test suite for list_available_models.

    These tests run against whatever model files are present on disk. All
    content checks are per-item, so they pass trivially when no models are
    installed.
    """

    def test_list_returns_list(self):
        """list_available_models should return a list."""
        result = list_available_models()
        assert isinstance(result, list), f"Expected list, got {type(result)}"

    def test_list_with_mapping_returns_tuples(self):
        """list_available_models(return_mapping=True) should return list of tuples."""
        result = list_available_models(return_mapping=True)
        assert isinstance(result, list), f"Expected list, got {type(result)}"
        # Tuple shape can only be checked when at least one model exists.
        if result:
            assert all(
                isinstance(item, tuple) and len(item) == 2
                for item in result
            ), "Each item should be a (display_name, full_path) tuple"

    def test_list_filters_valid_extensions(self):
        """Only valid model extensions should be returned (no .gguf)."""
        valid_extensions = (".safetensors", ".pt", ".pth")
        result = list_available_models(return_mapping=True)
        for display_name, _full_path in result:
            ext = os.path.splitext(display_name.lower())[1]
            assert ext in valid_extensions, (
                f"Invalid extension {ext} in {display_name}"
            )

    def test_list_returns_basenames_by_default(self):
        """Default return (no return_mapping argument) should be basenames only.

        The default call is exercised explicitly, alongside
        return_mapping=False. Items must be plain strings: with tuples,
        `"/" not in item` would test tuple membership and pass vacuously.
        """
        for result in (
            list_available_models(),
            list_available_models(return_mapping=False),
        ):
            for name in result:
                assert isinstance(name, str), (
                    f"Expected str basename, got {type(name)}: {name!r}"
                )
                # Should not contain path separators
                assert "/" not in name and "\\" not in name, (
                    f"Expected basename, got path: {name}"
                )
class TestModelDetectionIntegration:
    """Integration tests for model detection with real file patterns.

    Both tests take arguments, so they must be parametrized. Without the
    decorators pytest looks for fixtures named 'filename' / 'expected' and
    errors out. The cases below mirror results established by the unit tests.
    """

    @pytest.mark.parametrize(
        ("filename", "expected"),
        [
            ("model.safetensors", "SD15"),
            ("dreamshaper_8.pt", "SD15"),
            ("anime_model.pth", "SD15"),
            ("DreamShaper_8_pruned.safetensors", "SD15"),
            ("Meina V10 - baked VAE.safetensors", "SD15"),
            ("realisticVisionV60.safetensors", "SD15"),
            ("/models/sdxl/base_model.safetensors", "SD15"),
            ("model.bin", "SD15"),
            ("juggernaut_sdxl_v9.safetensors", "SDXL"),
            ("sd_xl_refiner_1.0.safetensors", "SDXL"),
            ("hassakuXL_v13.safetensors", "SDXL"),
            ("Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors", "SDXL"),
            ("/models/checkpoints/sdxl_base_model.safetensors", "SDXL"),
        ],
    )
    def test_detection_matrix(self, filename, expected):
        """Test detection across a matrix of common model filenames."""
        result = detect_model_type(filename)
        assert result == expected, f"Expected {expected} for {filename}, got {result}"

    @pytest.mark.parametrize(
        "filename",
        [
            "flux1-dev-Q8_0.gguf",
            "FLUX_model.gguf",
            "random_model.gguf",
            "/models/flux/flux1-dev.gguf",
        ],
    )
    def test_gguf_files_raise_error(self, filename):
        """All GGUF files should raise ValueError."""
        with pytest.raises(ValueError, match="GGUF files not supported"):
            detect_model_type(filename)