| """ |
| Tests for model management functionality. |
| """ |
|
|
| import pytest |
| import tempfile |
| from pathlib import Path |
| from unittest.mock import Mock, patch, MagicMock |
| import json |
|
|
| from models import ( |
| ModelRegistry, |
| ModelInfo, |
| ModelStatus, |
| ModelTask, |
| ModelFramework, |
| ModelDownloader, |
| ModelLoader, |
| ModelOptimizer |
| ) |
|
|
|
|
class TestModelRegistry:
    """Tests for ``ModelRegistry``: registration, lookup, filtering and export."""

    @pytest.fixture
    def registry(self):
        """Yield a registry rooted in a throwaway directory.

        The directory is deleted when the test finishes, so repeated test
        runs do not accumulate stray directories in the system temp location.
        """
        with tempfile.TemporaryDirectory() as models_dir:
            yield ModelRegistry(models_dir=Path(models_dir))

    def test_registry_initialization(self, registry):
        """A fresh registry is non-empty and its models directory exists."""
        assert registry is not None
        assert len(registry.models) > 0
        assert registry.models_dir.exists()

    def test_register_model(self, registry):
        """Registering a new ``ModelInfo`` reports success and stores it by ID."""
        model = ModelInfo(
            model_id="test-model",
            name="Test Model",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="http://example.com/model.pth",
            filename="test.pth",
            file_size=1000000,
        )

        success = registry.register_model(model)

        assert success
        assert "test-model" in registry.models

    def test_get_model(self, registry):
        """Looking up ``rmbg-1.4`` returns a segmentation model with that ID."""
        # NOTE(review): relies on the registry containing an "rmbg-1.4" entry
        # right after construction -- confirm against the registry defaults.
        model = registry.get_model("rmbg-1.4")

        assert model is not None
        assert model.model_id == "rmbg-1.4"
        assert model.task == ModelTask.SEGMENTATION

    def test_list_models_by_task(self, registry):
        """Filtering by task returns a non-empty list of matching models only."""
        segmentation_models = registry.list_models(task=ModelTask.SEGMENTATION)

        assert len(segmentation_models) > 0
        assert all(m.task == ModelTask.SEGMENTATION for m in segmentation_models)

    def test_list_models_by_framework(self, registry):
        """Filtering by framework never returns models of another framework."""
        pytorch_models = registry.list_models(framework=ModelFramework.PYTORCH)
        onnx_models = registry.list_models(framework=ModelFramework.ONNX)

        # Either list may be empty; only the filter's exclusivity is checked.
        assert all(m.framework == ModelFramework.PYTORCH for m in pytorch_models)
        assert all(m.framework == ModelFramework.ONNX for m in onnx_models)

    def test_get_best_model(self, registry):
        """A best segmentation model is found with either speed preference."""
        best_accuracy = registry.get_best_model(
            ModelTask.SEGMENTATION,
            prefer_speed=False,
        )
        assert best_accuracy is not None

        best_speed = registry.get_best_model(
            ModelTask.SEGMENTATION,
            prefer_speed=True,
        )
        assert best_speed is not None

    def test_update_model_usage(self, registry):
        """Recording a use bumps ``use_count`` by one and sets ``last_used``."""
        model_id = "rmbg-1.4"
        initial_count = registry.models[model_id].use_count

        registry.update_model_usage(model_id)

        assert registry.models[model_id].use_count == initial_count + 1
        assert registry.models[model_id].last_used is not None

    def test_get_total_size(self, registry):
        """Total size is positive overall and zero for AVAILABLE models."""
        total_size = registry.get_total_size()
        assert total_size > 0

        # Nothing has been marked AVAILABLE in a fresh registry, so the
        # status-filtered total is expected to be exactly zero.
        available_size = registry.get_total_size(status=ModelStatus.AVAILABLE)
        assert available_size == 0

    def test_export_registry(self, registry, temp_dir):
        """Exporting writes a JSON file with a non-empty ``models`` entry."""
        # `temp_dir` is a fixture defined outside this class (not visible
        # here); it is used as a path-like directory.
        export_path = temp_dir / "registry_export.json"
        registry.export_registry(export_path)

        assert export_path.exists()

        # Explicit encoding keeps the read independent of the platform locale.
        with open(export_path, encoding="utf-8") as f:
            data = json.load(f)

        assert "models" in data
        assert len(data["models"]) > 0
|
|
|
|
class TestModelDownloader:
    """Tests for ``ModelDownloader``: download, progress tracking and cancel."""

    @pytest.fixture
    def downloader(self, mock_registry):
        """Build a downloader on top of the ``mock_registry`` fixture.

        ``mock_registry`` is defined outside this class (not visible here).
        """
        return ModelDownloader(mock_registry)

    @patch('requests.get')
    def test_download_model(self, mock_get, downloader):
        """A forced download issues an HTTP GET through ``requests.get``."""
        # Fake a streaming HTTP response carrying a single chunk of data.
        mock_response = MagicMock()
        mock_response.headers = {'content-length': '1000000'}
        mock_response.iter_content = MagicMock(
            return_value=[b'data' * 1000]
        )
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response

        # Only the outgoing request is verified, so the return value of
        # download_model is intentionally not captured.
        downloader.download_model("test-model", force=True)

        mock_get.assert_called()

    def test_download_progress_tracking(self, downloader):
        """Starting a download registers it in ``downloader.downloads``."""
        progress_values = []

        def progress_callback(progress):
            progress_values.append(progress.progress)

        # The actual transfer is patched out; the collected progress values
        # are not asserted on, only the bookkeeping entry is.
        with patch.object(downloader, '_download_model_task', return_value=True):
            downloader.download_model(
                "test-model",
                progress_callback=progress_callback,
            )

        assert "test-model" in downloader.downloads

    def test_cancel_download(self, downloader):
        """Cancelling a tracked download succeeds and sets its stop event."""
        # Pretend a download is in flight by seeding the internal tables.
        downloader.downloads["test-model"] = Mock()
        downloader._stop_events["test-model"] = Mock()

        success = downloader.cancel_download("test-model")

        assert success
        downloader._stop_events["test-model"].set.assert_called()

    def test_download_with_resume(self, downloader, temp_dir):
        """Create a partial download file as the precondition for a resume."""
        # NOTE(review): this test only verifies that the partial file was
        # written; it never invokes the downloader, so resume behaviour
        # itself is not exercised here.
        partial_file = temp_dir / "test.pth.part"
        partial_file.write_bytes(b"partial_data")

        assert partial_file.exists()
        assert partial_file.stat().st_size > 0
|
|
|
|
class TestModelLoader:
    """Tests for ``ModelLoader``: loading, unloading and memory accounting."""

    @pytest.fixture
    def loader(self, mock_registry):
        """Build a CPU loader on top of the ``mock_registry`` fixture.

        ``mock_registry`` is defined outside this class (not visible here).
        """
        return ModelLoader(mock_registry, device='cpu')

    def test_loader_initialization(self, loader):
        """The loader keeps its device and has a positive memory budget."""
        assert loader is not None
        assert loader.device == 'cpu'
        assert loader.max_memory_bytes > 0

    @patch('torch.load')
    def test_load_pytorch_model(self, mock_torch_load, loader):
        """Loading an AVAILABLE PyTorch model goes through ``torch.load``."""
        mock_model = MagicMock()
        mock_torch_load.return_value = mock_model

        model_info = ModelInfo(
            model_id="test-pytorch",
            name="Test PyTorch Model",
            version="1.0",
            task=ModelTask.SEGMENTATION,
            framework=ModelFramework.PYTORCH,
            url="",
            filename="model.pth",
            local_path="/tmp/model.pth",
            status=ModelStatus.AVAILABLE,
        )

        loader.registry.get_model = Mock(return_value=model_info)

        # Pretend the weights file is on disk; the returned model object is
        # not inspected, so it is not captured.
        with patch.object(Path, 'exists', return_value=True):
            loader.load_model("test-pytorch")

        mock_torch_load.assert_called()

    def test_memory_management(self, loader):
        """Freeing memory evicts at least one of the loaded models."""
        mib = 1024 * 1024

        # Five fake models of 100 MiB each, 500 MiB in total.
        for i in range(5):
            loader.loaded_models[f"model_{i}"] = Mock(
                memory_usage=100 * mib
            )

        loader.current_memory_usage = 500 * mib

        loader._free_memory(200 * mib)

        assert len(loader.loaded_models) < 5

    def test_unload_model(self, loader):
        """Unloading removes the model and returns its memory to the budget."""
        mib = 1024 * 1024

        loader.loaded_models["test"] = Mock(
            model=Mock(),
            memory_usage=100 * mib,
        )
        loader.current_memory_usage = 100 * mib

        success = loader.unload_model("test")

        assert success
        assert "test" not in loader.loaded_models
        assert loader.current_memory_usage == 0

    def test_get_memory_usage(self, loader):
        """Usage statistics report megabytes, model count and model names."""
        mib = 1024 * 1024

        loader.loaded_models["model1"] = Mock(memory_usage=100 * mib)
        loader.loaded_models["model2"] = Mock(memory_usage=200 * mib)
        loader.current_memory_usage = 300 * mib

        usage = loader.get_memory_usage()

        assert usage["current_usage_mb"] == 300
        assert usage["loaded_models"] == 2
        assert "model1" in usage["models"]
        assert "model2" in usage["models"]
|
|
|
|
class TestModelOptimizer:
    """Tests for ``ModelOptimizer``: quantization and result records."""

    @pytest.fixture
    def optimizer(self, mock_registry):
        """Build an optimizer around a CPU ``ModelLoader``.

        ``mock_registry`` is defined outside this class (not visible here).
        """
        loader = ModelLoader(mock_registry, device='cpu')
        return ModelOptimizer(loader)

    @patch('torch.quantization.quantize_dynamic')
    def test_quantize_pytorch_model(self, mock_quantize, optimizer):
        """Dynamic quantization delegates to ``quantize_dynamic``."""
        mock_model = MagicMock()
        mock_quantize.return_value = mock_model

        loaded = Mock(
            model_id="test",
            model=mock_model,
            framework=ModelFramework.PYTORCH,
            metadata={'input_size': (1, 3, 512, 512)},
        )

        # Size measurement and benchmarking are stubbed out so that only the
        # quantization call itself is exercised. The output directory comes
        # from the platform temp location instead of a hard-coded POSIX path.
        with patch.object(optimizer, '_get_model_size', return_value=1000000), \
                patch.object(optimizer, '_benchmark_model', return_value=0.1):
            optimizer._quantize_pytorch(
                loaded,
                Path(tempfile.gettempdir()),
                "dynamic",
            )

        mock_quantize.assert_called()

    def test_optimization_result(self, optimizer):
        """``OptimizationResult`` exposes the values it was constructed with."""
        from models.optimizer import OptimizationResult

        result = OptimizationResult(
            original_size_mb=100,
            optimized_size_mb=25,
            compression_ratio=4.0,
            original_speed_ms=100,
            optimized_speed_ms=50,
            speedup=2.0,
            accuracy_loss=0.01,
            optimization_time=10.0,
            output_path="/tmp/optimized.pth",
        )

        assert result.compression_ratio == 4.0
        assert result.speedup == 2.0
        assert result.accuracy_loss == 0.01
|
|
|
|
class TestModelIntegration:
    """End-to-end checks that combine several model-management components."""

    @pytest.mark.integration
    @pytest.mark.slow
    def test_model_registry_persistence(self, temp_dir):
        """A model registered by one registry is visible to a second one.

        Both registries are created over the same ``temp_dir`` (a fixture
        defined outside this class).
        """
        model_fields = {
            "model_id": "persistence-test",
            "name": "Persistence Test",
            "version": "1.0",
            "task": ModelTask.SEGMENTATION,
            "framework": ModelFramework.PYTORCH,
            "url": "http://example.com/model.pth",
            "filename": "persist.pth",
        }

        # First instance: create the registry, then register the entry.
        writer = ModelRegistry(models_dir=temp_dir)
        writer.register_model(ModelInfo(**model_fields))

        # Second instance over the same directory should see that entry.
        reader = ModelRegistry(models_dir=temp_dir)
        restored = reader.get_model("persistence-test")

        assert restored is not None
        assert restored.name == "Persistence Test"

    @pytest.mark.integration
    def test_model_manager_workflow(self):
        """The model manager reports registry stats and runs a benchmark."""
        from models import create_model_manager

        model_manager = create_model_manager()

        summary = model_manager.get_stats()
        assert "registry" in summary
        assert summary["registry"]["total_models"] > 0

        # Model loading is replaced by a stub while the benchmark runs.
        stub_loader = patch.object(
            model_manager.loader, 'load_model', return_value=Mock()
        )
        with stub_loader:
            timings = model_manager.benchmark()

        assert isinstance(timings, dict)