# LightDiffusion-Next / tests/unit/test_model_detection.py
# Commit b701455 (Aatricks): "Deploy ZeroGPU Gradio Space snapshot"
"""
Unit tests for model detection functionality.
Tests the detect_model_type and list_available_models functions in
src/Core/Models/ModelFactory.py with various filename patterns and edge cases.
Note: GGUF/FLUX support has been removed. GGUF files now raise ValueError.
"""
import os
import sys
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
# Put the repository root (three directories above this file) at the very
# front of the import search path so the ``src`` package import below
# resolves no matter which directory pytest is launched from.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path[:0] = [os.fspath(project_root)]
from src.Core.Models.ModelFactory import detect_model_type, list_available_models
class TestDetectModelType:
    """Test suite for detect_model_type function.

    Covers the SD1.5 default, SDXL filename markers, rejection of GGUF
    files with ``ValueError``, and degenerate inputs (None, empty string,
    unknown or missing extension). Multi-case checks are parametrized so
    each input is reported as its own test.
    """

    # =========================================================================
    # SD1.5 Detection Tests
    # =========================================================================

    def test_detect_sd15_from_generic_safetensors(self):
        """SD1.5 should be detected for generic .safetensors files."""
        result = detect_model_type("model.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_pt_file(self):
        """SD1.5 should be detected for .pt files without SDXL marker."""
        result = detect_model_type("dreamshaper_8.pt")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_pth_file(self):
        """SD1.5 should be detected for .pth files without SDXL marker."""
        result = detect_model_type("anime_model.pth")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_dreamshaper(self):
        """DreamShaper models should be detected as SD1.5."""
        result = detect_model_type("DreamShaper_8_pruned.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_meina(self):
        """Meina models should be detected as SD1.5."""
        result = detect_model_type("Meina V10 - baked VAE.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    def test_detect_sd15_from_realistic_vision(self):
        """Realistic Vision models should be detected as SD1.5."""
        result = detect_model_type("realisticVisionV60.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    @pytest.mark.parametrize(
        "path",
        [
            "C:\\Models\\checkpoints\\my_model.safetensors",  # Windows-style
            "/home/user/models/my_model.safetensors",  # Unix-style
        ],
        ids=["windows", "unix"],
    )
    def test_detect_sd15_with_absolute_path(self, path):
        """Detection should work with absolute Windows- and Unix-style paths.

        Args:
            path: An absolute path whose filename carries no SDXL marker.
        """
        # NOTE(review): on POSIX, os.path does not treat "\\" as a separator,
        # so the Windows-style string may be inspected as a whole; it still
        # contains none of the SDXL markers exercised below.
        result = detect_model_type(path)
        assert result == "SD15", f"Expected SD15 for {path}, got {result}"

    def test_detect_sd15_with_relative_path(self):
        """Detection should work with relative paths."""
        result = detect_model_type("./include/checkpoints/model.safetensors")
        assert result == "SD15", f"Expected SD15, got {result}"

    # =========================================================================
    # SDXL Detection Tests
    # =========================================================================

    def test_detect_sdxl_from_filename_marker(self):
        """SDXL should be detected from 'sdxl' in filename."""
        result = detect_model_type("juggernaut_sdxl_v9.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"

    @pytest.mark.parametrize(
        "filename",
        [
            "SDXL_model.safetensors",
            "Sdxl_model.safetensors",
            "model_SDXL.safetensors",
            "mySdXlModel.safetensors",
        ],
    )
    def test_detect_sdxl_case_insensitive(self, filename):
        """SDXL detection should be case-insensitive.

        Args:
            filename: A filename containing 'sdxl' in some letter casing.
        """
        result = detect_model_type(filename)
        assert result == "SDXL", f"Expected SDXL for {filename}, got {result}"

    def test_detect_sdxl_from_refiner(self):
        """SDXL should be detected from 'refiner' in filename."""
        result = detect_model_type("sd_xl_refiner_1.0.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"

    def test_detect_sdxl_from_hassaku(self):
        """SDXL should be detected from 'hassaku' in filename."""
        result = detect_model_type("hassakuXL_v13.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"

    def test_detect_sdxl_juggernaut(self):
        """Juggernaut XL models should be detected as SDXL due to 'juggernaut' indicator."""
        result = detect_model_type("Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors")
        assert result == "SDXL", f"Expected SDXL (juggernaut indicator), got {result}"

    def test_detect_sdxl_with_path(self):
        """SDXL detection works with full paths if basename contains marker."""
        # Note: Detection is on os.path.basename(lp), not full path
        result = detect_model_type("/models/checkpoints/sdxl_base_model.safetensors")
        assert result == "SDXL", f"Expected SDXL, got {result}"
        # Path with sdxl in directory but not in filename defaults to SD15
        result_nomarker = detect_model_type("/models/sdxl/base_model.safetensors")
        assert result_nomarker == "SD15", (
            f"Expected SD15 (marker not in basename), got {result_nomarker}"
        )

    # =========================================================================
    # GGUF Files - No Longer Supported (Must Raise ValueError)
    # =========================================================================

    def test_gguf_files_raise_value_error(self):
        """GGUF files should raise ValueError as they're no longer supported."""
        with pytest.raises(ValueError, match="GGUF files not supported"):
            detect_model_type("flux1-dev-Q8_0.gguf")

    @pytest.mark.parametrize(
        "filename",
        [
            "my_flux_model.gguf",
            "FLUX_model.gguf",
            "Flux_model.gguf",
            "model_FLUX.gguf",
            "random_model.gguf",
            "/models/flux/flux1-dev.gguf",
        ],
    )
    def test_gguf_any_filename_raises_error(self, filename):
        """Any .gguf file should raise ValueError, regardless of name or path.

        Args:
            filename: A .gguf filename or path.
        """
        with pytest.raises(ValueError, match="GGUF files not supported"):
            detect_model_type(filename)

    # =========================================================================
    # Edge Cases and Error Handling
    # =========================================================================

    def test_detect_none_input(self):
        """None input should return SD15 (default)."""
        result = detect_model_type(None)
        assert result == "SD15", f"Expected SD15 for None input, got {result}"

    def test_detect_empty_string(self):
        """Empty string should return SD15 (default)."""
        result = detect_model_type("")
        assert result == "SD15", f"Expected SD15 for empty string, got {result}"

    def test_detect_unknown_extension(self):
        """Unknown extensions should default to SD15."""
        result = detect_model_type("model.bin")
        assert result == "SD15", f"Expected SD15 for .bin file, got {result}"

    def test_detect_no_extension(self):
        """Files without extension should default to SD15."""
        result = detect_model_type("model_file")
        assert result == "SD15", f"Expected SD15 for no extension, got {result}"

    def test_detect_preserves_original_path(self):
        """Detection should leave its input untouched and be repeatable.

        Python strings are immutable, so comparing the variable afterwards
        can never fail by itself. The meaningful check is that calling the
        function again with the same path gives the same answer.
        """
        original_path = "path/to/model.safetensors"
        first = detect_model_type(original_path)
        second = detect_model_type(original_path)
        assert original_path == "path/to/model.safetensors"
        assert first == second == "SD15", (
            f"Expected stable SD15 result, got {first!r} then {second!r}"
        )
class TestListAvailableModels:
    """Test suite for list_available_models function.

    These tests call the real listing function; no fixture or mock is
    installed here. When no model files are present, the per-item checks
    therefore pass vacuously.
    """

    def test_list_returns_list(self):
        """list_available_models should return a list."""
        result = list_available_models()
        assert isinstance(result, list), f"Expected list, got {type(result)}"

    def test_list_with_mapping_returns_tuples(self):
        """list_available_models(return_mapping=True) should return list of tuples."""
        result = list_available_models(return_mapping=True)
        assert isinstance(result, list), f"Expected list, got {type(result)}"
        # Check every entry individually so a failure names the offending item.
        for item in result:
            assert isinstance(item, tuple) and len(item) == 2, (
                f"Each item should be a (display_name, full_path) tuple, got {item!r}"
            )

    def test_list_filters_valid_extensions(self):
        """Only valid model extensions should be returned (no .gguf)."""
        valid_extensions = (".safetensors", ".pt", ".pth")
        result = list_available_models(return_mapping=True)
        for display_name, _full_path in result:
            # Case-insensitive so e.g. "Model.SafeTensors" is still accepted.
            ext = os.path.splitext(display_name.lower())[1]
            assert ext in valid_extensions, (
                f"Invalid extension {ext} in {display_name}"
            )

    def test_list_returns_basenames_by_default(self):
        """Default return (and explicit return_mapping=False) should be basenames only."""
        # Exercise both the no-argument default and the explicit flag.
        for result in (
            list_available_models(),
            list_available_models(return_mapping=False),
        ):
            for name in result:
                # A non-string (e.g. a mapping tuple) would make the separator
                # check below meaningless, so require plain strings first.
                assert isinstance(name, str), (
                    f"Expected basename string, got {type(name)}: {name!r}"
                )
                # Should not contain path separators
                assert "/" not in name and "\\" not in name, (
                    f"Expected basename, got path: {name}"
                )
class TestModelDetectionIntegration:
    """Integration tests for model detection with real file patterns.

    Runs ``detect_model_type`` over a table of commonly distributed
    checkpoint filenames, and confirms every ``.gguf`` name is rejected.
    """

    @pytest.mark.parametrize(
        ("filename", "expected"),
        [
            # --- SD1.5 checkpoints -------------------------------------------
            ("DreamShaper_8_pruned.safetensors", "SD15"),
            ("v1-5-pruned.safetensors", "SD15"),
            ("anythingV5.safetensors", "SD15"),
            ("deliberate_v3.safetensors", "SD15"),
            ("realisticVision.safetensors", "SD15"),
            # --- SDXL checkpoints --------------------------------------------
            # Each name carries at least one of: 'sdxl', 'refiner', 'hassaku',
            # 'juggernaut', or 'xl'.
            ("sd_xl_base_1.0.safetensors", "SDXL"),
            ("Juggernaut-XL_v9.safetensors", "SDXL"),
            ("sdxl_vae.safetensors", "SDXL"),
            ("hassakuXLv13.safetensors", "SDXL"),
            ("SDXL_refiner_1.0.safetensors", "SDXL"),
        ],
    )
    def test_detection_matrix(self, filename, expected):
        """Each filename in the table must map to its expected architecture."""
        detected = detect_model_type(filename)
        assert detected == expected, f"Expected {expected} for {filename}, got {detected}"

    @pytest.mark.parametrize(
        "filename",
        ["flux1-dev-Q8_0.gguf", "flux-schnell.gguf", "any_model.gguf"],
    )
    def test_gguf_files_raise_error(self, filename):
        """Every .gguf filename must be rejected with ValueError."""
        with pytest.raises(ValueError, match="GGUF files not supported"):
            detect_model_type(filename)