Spaces:
Running on Zero
Running on Zero
File size: 10,920 Bytes
b701455 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 | """
Unit tests for model detection functionality.
Tests the detect_model_type and list_available_models functions in src/Core/Models/ModelFactory.py
with various filename patterns and edge cases.
Note: GGUF/FLUX support has been removed. GGUF files now raise ValueError.
"""
import os
import sys
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
# Add project root to path
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
from src.Core.Models.ModelFactory import detect_model_type, list_available_models
class TestDetectModelType:
    """Test suite for the detect_model_type function.

    The cases below pin the expected classification: filenames whose
    basename carries an SDXL marker map to "SDXL", ``.gguf`` files raise
    ValueError, and everything else (including None, "" and unknown
    extensions) falls back to "SD15".
    """

    # =========================================================================
    # SD1.5 Detection Tests
    # =========================================================================

    def test_detect_sd15_from_generic_safetensors(self):
        """SD1.5 should be detected for generic .safetensors files."""
        result = detect_model_type("model.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_pt_file(self):
        """SD1.5 should be detected for .pt files without SDXL marker."""
        result = detect_model_type("dreamshaper_8.pt")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_pth_file(self):
        """SD1.5 should be detected for .pth files without SDXL marker."""
        result = detect_model_type("anime_model.pth")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_dreamshaper(self):
        """DreamShaper models should be detected as SD1.5."""
        result = detect_model_type("DreamShaper_8_pruned.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_meina(self):
        """Meina models should be detected as SD1.5."""
        result = detect_model_type("Meina V10 - baked VAE.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_realistic_vision(self):
        """Realistic Vision models should be detected as SD1.5."""
        result = detect_model_type("realisticVisionV60.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_with_absolute_path(self):
        """Detection should work with absolute paths."""
        # Windows-style path.
        # NOTE(review): on POSIX, os.path.basename does not split on
        # backslashes, so the whole string is presumably scanned for markers;
        # this case passes only because no directory component contains one.
        result = detect_model_type("C:\\Models\\checkpoints\\my_model.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"
        # Unix-style path
        result = detect_model_type("/home/user/models/my_model.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_with_relative_path(self):
        """Detection should work with relative paths."""
        result = detect_model_type("./include/checkpoints/model.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    # =========================================================================
    # SDXL Detection Tests
    # =========================================================================

    def test_detect_sdxl_from_filename_marker(self):
        """SDXL should be detected from 'sdxl' in filename."""
        result = detect_model_type("juggernaut_sdxl_v9.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"

    def test_detect_sdxl_case_insensitive(self):
        """SDXL detection should be case-insensitive."""
        test_cases = [
            "SDXL_model.safetensors",
            "Sdxl_model.safetensors",
            "model_SDXL.safetensors",
            "mySdXlModel.safetensors",
        ]
        for filename in test_cases:
            result = detect_model_type(filename)
            assert result == "SDXL", f"Expected SDXL for {filename}, got {result}"

    def test_detect_sdxl_from_refiner(self):
        """SDXL should be detected for a refiner checkpoint name."""
        # NOTE(review): "sd_xl_refiner" also contains the generic 'xl' marker,
        # so this does not isolate the 'refiner' indicator on its own.
        result = detect_model_type("sd_xl_refiner_1.0.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"

    def test_detect_sdxl_from_hassaku(self):
        """SDXL should be detected for a Hassaku XL checkpoint name."""
        # NOTE(review): "hassakuXL" also contains 'xl', so this does not
        # isolate the 'hassaku' indicator on its own.
        result = detect_model_type("hassakuXL_v13.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"

    def test_detect_sdxl_juggernaut(self):
        """Juggernaut XL models should be detected as SDXL."""
        # NOTE(review): "Juggernaut-XL" also contains 'xl', so this does not
        # isolate the 'juggernaut' indicator on its own.
        result = detect_model_type("Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors")
        assert result == "SDXL", f"Expected SDXL (juggernaut indicator), got {result}"

    def test_detect_sdxl_with_path(self):
        """SDXL detection works with full paths if basename contains marker."""
        # Detection is on os.path.basename(lp), not the full path.
        result = detect_model_type("/models/checkpoints/sdxl_base_model.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"
        # Path with sdxl in directory but not in filename defaults to SD15
        result_nomarker = detect_model_type("/models/sdxl/base_model.safetensors")
        assert result_nomarker == "SD15", f"Expected SD15 (marker not in basename), got {result_nomarker}"

    # =========================================================================
    # GGUF Files - No Longer Supported (Must Raise ValueError)
    # =========================================================================

    def test_gguf_files_raise_value_error(self):
        """GGUF files should raise ValueError as they're no longer supported."""
        with pytest.raises(ValueError, match="GGUF files not supported"):
            detect_model_type("flux1-dev-Q8_0.gguf")

    def test_gguf_any_filename_raises_error(self):
        """Any .gguf file should raise ValueError."""
        test_cases = [
            "my_flux_model.gguf",
            "FLUX_model.gguf",
            "Flux_model.gguf",
            "model_FLUX.gguf",
            "random_model.gguf",
            "/models/flux/flux1-dev.gguf",
        ]
        for filename in test_cases:
            with pytest.raises(ValueError, match="GGUF files not supported"):
                detect_model_type(filename)

    # =========================================================================
    # Edge Cases and Error Handling
    # =========================================================================

    def test_detect_none_input(self):
        """None input should return SD15 (default)."""
        result = detect_model_type(None)
        assert result == "SD15", f"Expected SD15 for None input, got {result}"

    def test_detect_empty_string(self):
        """Empty string should return SD15 (default)."""
        result = detect_model_type("")
        assert result == "SD15", f"Expected SD15 for empty string, got {result}"

    def test_detect_unknown_extension(self):
        """Unknown extensions should default to SD15."""
        result = detect_model_type("model.bin")
        assert result == "SD15", f"Expected SD15 for .bin file, got {result}"

    def test_detect_no_extension(self):
        """Files without extension should default to SD15."""
        result = detect_model_type("model_file")
        assert result == "SD15", f"Expected SD15 for no extension, got {result}"

    def test_detect_preserves_original_path(self):
        """Detection must leave its input untouched and be deterministic.

        Python strings are immutable, so merely re-checking the literal can
        never fail; the meaningful guarantees are that repeated calls on the
        same input agree and still produce the expected classification.
        """
        original_path = "path/to/model.safetensors"
        first = detect_model_type(original_path)
        second = detect_model_type(original_path)
        assert original_path == "path/to/model.safetensors"
        assert first == second, (
            f"Detection is not deterministic: {first!r} vs {second!r}"
        )
        assert first == "SD15", f"Expected SD15, got {first}"
class TestListAvailableModels:
    """Test suite for list_available_models function.

    These tests inspect whatever models the function finds; when it returns
    an empty list, the per-item checks pass trivially.
    """

    def test_list_returns_list(self):
        """list_available_models should return a list."""
        result = list_available_models()
        assert isinstance(result, list), f"Expected list, got {type(result)}"

    def test_list_with_mapping_returns_tuples(self):
        """list_available_models(return_mapping=True) should return list of tuples."""
        result = list_available_models(return_mapping=True)
        assert isinstance(result, list), f"Expected list, got {type(result)}"
        # If non-empty, check tuple format
        if result:
            assert all(
                isinstance(item, tuple) and len(item) == 2
                for item in result
            ), "Each item should be a (display_name, full_path) tuple"

    def test_list_filters_valid_extensions(self):
        """Only valid model extensions should be returned (no .gguf)."""
        valid_extensions = (".safetensors", ".pt", ".pth")
        result = list_available_models(return_mapping=True)
        for display_name, full_path in result:
            # Check both halves of the mapping: the display name is what is
            # shown, and the full path is presumably what callers load.
            # str() tolerates path-like values for full_path.
            for candidate in (display_name, full_path):
                ext = os.path.splitext(str(candidate).lower())[1]
                assert ext in valid_extensions, (
                    f"Invalid extension {ext} in {candidate}"
                )

    def test_list_returns_basenames_by_default(self):
        """Default return (and return_mapping=False) should be basenames only."""
        # Call with no arguments so the *default* is actually exercised, and
        # keep the explicit return_mapping=False call for coverage.
        for result in (
            list_available_models(),
            list_available_models(return_mapping=False),
        ):
            for name in result:
                # Should not contain path separators
                assert "/" not in name and "\\" not in name, (
                    f"Expected basename, got path: {name}"
                )
class TestModelDetectionIntegration:
    """Table-driven checks of detect_model_type over realistic checkpoint names."""

    @pytest.mark.parametrize(
        ("filename", "expected"),
        [
            # --- SD1.5 checkpoints: no SDXL marker in the name ---
            ("DreamShaper_8_pruned.safetensors", "SD15"),
            ("v1-5-pruned.safetensors", "SD15"),
            ("anythingV5.safetensors", "SD15"),
            ("deliberate_v3.safetensors", "SD15"),
            ("realisticVision.safetensors", "SD15"),
            # --- SDXL checkpoints: name carries 'sdxl', 'refiner', 'hassaku',
            #     'juggernaut' or a bare 'xl' ---
            ("sd_xl_base_1.0.safetensors", "SDXL"),  # bare 'xl'
            ("Juggernaut-XL_v9.safetensors", "SDXL"),  # 'juggernaut' plus 'xl'
            ("sdxl_vae.safetensors", "SDXL"),
            ("hassakuXLv13.safetensors", "SDXL"),
            ("SDXL_refiner_1.0.safetensors", "SDXL"),
        ],
    )
    def test_detection_matrix(self, filename, expected):
        """Each listed filename must map to its expected model family."""
        detected = detect_model_type(filename)
        assert detected == expected, f"Expected {expected} for {filename}, got {detected}"

    @pytest.mark.parametrize(
        "filename",
        ["flux1-dev-Q8_0.gguf", "flux-schnell.gguf", "any_model.gguf"],
    )
    def test_gguf_files_raise_error(self, filename):
        """Every .gguf input is rejected with ValueError."""
        with pytest.raises(ValueError, match="GGUF files not supported"):
            detect_model_type(filename)
|