import itertools
import os
from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp
import pytest
from aiohttp import ClientResponse

from model_filemanager import (
    DownloadModelStatus,
    DownloadStatusType,
    check_file_exists,
    create_model_path,
    download_model,
    track_download_progress,
    validate_filename,
    validate_model_subdirectory,
)
| |
|
class AsyncIteratorMock:
    """Minimal asynchronous iterator over a pre-built sequence.

    Stands in for the async iterator that aiohttp's response content
    produces, so download code can be driven without a real network stream.
    """

    # Private marker handed back by ``next()`` once the wrapped iterator is
    # exhausted; being a fresh object it can never equal a real item.
    _EXHAUSTED = object()

    def __init__(self, seq):
        # The sequence is turned into an iterator right away, so each
        # instance can be consumed exactly once.
        self.iter = iter(seq)

    def __aiter__(self):
        """Return this object; it is its own asynchronous iterator."""
        return self

    async def __anext__(self):
        """Return the next item, raising StopAsyncIteration when none remain."""
        item = next(self.iter, self._EXHAUSTED)
        if item is self._EXHAUSTED:
            raise StopAsyncIteration
        return item
| |
|
class ContentMock:
    """Fake for the ``content`` attribute of an aiohttp ClientResponse.

    Only ``iter_chunked`` is provided: it hands back the pre-split chunks
    given at construction time instead of reading from a socket.
    """

    def __init__(self, chunks):
        # Stored as given; for a re-iterable such as a list, every
        # iter_chunked() call starts again from the first chunk.
        self.chunks = chunks

    def iter_chunked(self, chunk_size):
        """Return a new async iterator over the stored chunks.

        ``chunk_size`` exists only to match aiohttp's signature and is
        ignored; chunks are yielded exactly as they were supplied.
        """
        stream = AsyncIteratorMock(self.chunks)
        return stream
| |
|
@pytest.mark.asyncio
async def test_download_model_success():
    """A 200 response is streamed to the (mocked) file and reported as completed."""
    payload_chunks = [b'a' * 500, b'b' * 300, b'c' * 200]

    # Fake 200 response whose body arrives as the three chunks above.
    fake_response = AsyncMock(spec=aiohttp.ClientResponse)
    fake_response.status = 200
    fake_response.headers = {'Content-Length': '1000'}
    fake_response.content = ContentMock(payload_chunks)

    request_fn = AsyncMock(return_value=fake_response)
    progress_cb = AsyncMock()

    # open() is replaced, so writes land on file_handle instead of disk.
    file_handle = MagicMock()
    fake_open = MagicMock()
    fake_open.return_value.__enter__.return_value = file_handle

    # Every time.time() call returns the next value: 0, 0.1, 0.2, ...
    fake_clock = itertools.count(0, 0.1)

    with patch('model_filemanager.create_model_path',
               return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
            patch('model_filemanager.check_file_exists', return_value=None), \
            patch('builtins.open', fake_open), \
            patch('time.time', side_effect=fake_clock):
        result = await download_model(
            request_fn,
            'model.sft',
            'http://example.com/model.sft',
            'checkpoints',
            progress_cb,
        )

    relative_path = 'checkpoints/model.sft'

    assert isinstance(result, DownloadModelStatus)
    assert result.message == 'Successfully downloaded model.sft'
    assert result.status == 'completed'
    assert result.already_existed is False

    # At least three callbacks (presumably start, progress and completion).
    assert progress_cb.call_count >= 3

    # The first and last reported statuses must both have been emitted.
    expected_bookends = (
        DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False),
        DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False),
    )
    for expected_status in expected_bookends:
        progress_cb.assert_any_call(relative_path, expected_status)

    # Each chunk must have been written to the mocked file handle.
    for chunk in payload_chunks:
        file_handle.write.assert_any_call(chunk)

    request_fn.assert_called_once_with('http://example.com/model.sft')
| |
|
@pytest.mark.asyncio
async def test_download_model_url_request_failure():
    """A non-200 response yields an ERROR status and reports it via the callback."""
    model_url = 'http://example.com/model.safetensors'
    error_message = 'Failed to download model.safetensors. Status code: 404'
    # Relative path the progress callback is expected to receive. The stubbed
    # create_model_path now returns this same value; it previously returned
    # 'mock/path/model.safetensors', which contradicted the callback
    # assertions below. With both agreeing, the assertions hold whether
    # download_model reports the stubbed relative path or one it derives from
    # the subdirectory and model name.
    relative_path = 'mock_directory/model.safetensors'

    mock_response = AsyncMock(spec=ClientResponse)
    mock_response.status = 404  # simulate "Not Found"
    mock_get = AsyncMock(return_value=mock_response)
    mock_progress_callback = AsyncMock()

    # NOTE(review): these patches replace attributes on the model_filemanager
    # package; if download_model looks these helpers up in a submodule the
    # patches may not take effect -- confirm the patch target.
    with patch('model_filemanager.create_model_path',
               return_value=('/mock/path/model.safetensors', relative_path)), \
            patch('model_filemanager.check_file_exists', return_value=None):
        result = await download_model(
            mock_get,
            'model.safetensors',
            model_url,
            'mock_directory',
            mock_progress_callback,
        )

    assert isinstance(result, DownloadModelStatus)
    assert result.status == 'error'
    assert result.message == error_message
    assert result.already_existed is False

    # The download is announced first ...
    mock_progress_callback.assert_any_call(
        relative_path,
        DownloadModelStatus(
            status=DownloadStatusType.PENDING,
            progress_percentage=0,
            message='Starting download of model.safetensors',
            already_existed=False,
        ),
    )
    # ... and the most recent callback reports the failure.
    mock_progress_callback.assert_called_with(
        relative_path,
        DownloadModelStatus(
            status=DownloadStatusType.ERROR,
            progress_percentage=0,
            message=error_message,
            already_existed=False,
        ),
    )

    mock_get.assert_called_once_with(model_url)
| |
|
@pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory():
    """A path-traversal subdirectory ('../bad_path') is rejected with an error status."""
    request_fn = AsyncMock()
    progress_cb = AsyncMock()

    status = await download_model(
        request_fn,
        'model.sft',
        'http://example.com/model.sft',
        '../bad_path',
        progress_cb,
    )

    assert isinstance(status, DownloadModelStatus)
    assert status.message == 'Invalid model subdirectory'
    assert status.status == 'error'
    assert status.already_existed is False
| |
|
| |
|
| | |
def test_create_model_path(tmp_path, monkeypatch):
    """create_model_path returns the absolute and relative paths and creates the folder."""
    models_root = tmp_path / "models"
    monkeypatch.setattr('folder_paths.models_dir', str(models_root))

    name, subdir = "test_model.sft", "test_dir"
    absolute, relative = create_model_path(name, subdir, models_root)

    expected_absolute = models_root / subdir / name
    assert absolute == str(expected_absolute)
    assert relative == f"{subdir}/{name}"
    # The containing directory must exist once the call has returned.
    assert os.path.exists(os.path.dirname(absolute))
| |
|
| |
|
@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
    """An existing file is reported as COMPLETED with already_existed=True."""
    existing = tmp_path / "existing_model.sft"
    existing.touch()
    callback = AsyncMock()
    expected_message = "existing_model.sft already exists"

    outcome = await check_file_exists(
        str(existing),
        "existing_model.sft",
        callback,
        "test/existing_model.sft",
    )

    assert outcome is not None
    assert outcome.status == "completed"
    assert outcome.message == expected_message
    assert outcome.already_existed is True

    # The same status must be pushed to the callback exactly once.
    callback.assert_called_once_with(
        "test/existing_model.sft",
        DownloadModelStatus(DownloadStatusType.COMPLETED, 100, expected_message, already_existed=True),
    )
| |
|
@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
    """A missing file yields None and the callback is never invoked."""
    missing = tmp_path / "non_existing_model.sft"  # intentionally never created
    callback = AsyncMock()

    outcome = await check_file_exists(
        str(missing),
        "non_existing_model.sft",
        callback,
        "test/non_existing_model.sft",
    )

    assert outcome is None
    callback.assert_not_called()
| |
|
@pytest.mark.asyncio
async def test_track_download_progress_no_content_length():
    """Without a Content-Length header the download completes and reports 0% in progress."""
    response = AsyncMock(spec=aiohttp.ClientResponse)
    response.headers = {}  # no Content-Length
    response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])

    callback = AsyncMock()
    fake_open = MagicMock(return_value=MagicMock())

    with patch('builtins.open', fake_open):
        status = await track_download_progress(
            response,
            '/mock/path/model.sft',
            'model.sft',
            callback,
            'models/model.sft',
            interval=0.1,
        )

    assert status.status == "completed"
    callback.assert_any_call(
        'models/model.sft',
        DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False),
    )
| |
|
@pytest.mark.asyncio
async def test_track_download_progress_interval():
    """With a simulated clock, intermediate updates are emitted and the run ends at 100%."""
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {'Content-Length': '1000'}
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)

    mock_callback = AsyncMock()
    mock_open = MagicMock(return_value=MagicMock())

    # Simulated clock: successive time.time() calls return 0, 0.5, 1.0, ...
    # An unbounded counter (mirroring the itertools.count clock used by
    # test_download_model_success) replaces the former 30-element list, which
    # would raise StopIteration as soon as time.time() was called more often
    # than the list allowed while patched. The first 30 values are identical.
    simulated_clock = itertools.count(0, 0.5)

    with patch('builtins.open', mock_open), \
            patch('time.time', side_effect=simulated_clock):
        await track_download_progress(
            mock_response, '/mock/path/model.sft', 'model.sft',
            mock_callback, 'models/model.sft', interval=1.0,
        )

    calls = mock_callback.call_args_list
    # (status, progress) per callback; included in assertion messages so it is
    # shown on failure, instead of being printed unconditionally.
    summary = [(args[1].status, args[1].progress_percentage) for args, _kwargs in calls]

    assert mock_callback.call_count >= 3, (
        f"Expected at least 3 calls, but got {mock_callback.call_count}: {summary}"
    )

    first_status = calls[0][0][1]
    assert first_status.status == "in_progress", summary
    assert 0 <= first_status.progress_percentage < 50, (
        f"First call progress was {first_status.progress_percentage}%"
    )

    last_status = calls[-1][0][1]
    assert last_status.status == "completed", summary
    assert last_status.progress_percentage == 100, summary
| |
|
def test_valid_subdirectory():
    """Letters, digits and a hyphen form a valid subdirectory name."""
    is_valid = validate_model_subdirectory("valid-model123")
    assert is_valid is True
| |
|
def test_subdirectory_too_long():
    """A 51-character subdirectory name is rejected as too long."""
    overlong_name = "a" * 51
    assert validate_model_subdirectory(overlong_name) is False
| |
|
def test_subdirectory_with_double_dots():
    """A parent-directory segment ('..') is rejected."""
    traversal = "model/../unsafe"
    assert validate_model_subdirectory(traversal) is False
| |
|
def test_subdirectory_with_slash():
    """A nested path containing '/' is rejected."""
    nested = "model/unsafe"
    assert validate_model_subdirectory(nested) is False
| |
|
def test_subdirectory_with_special_characters():
    """A name containing '@' is rejected."""
    with_symbol = "model@unsafe"
    assert validate_model_subdirectory(with_symbol) is False
| |
|
def test_subdirectory_with_underscore_and_dash():
    """Underscores and hyphens are both accepted."""
    mixed_separators = "valid_model-name"
    assert validate_model_subdirectory(mixed_separators) is True
| |
|
def test_empty_subdirectory():
    """An empty subdirectory name is rejected."""
    empty_name = ""
    assert validate_model_subdirectory(empty_name) is False
| |
|
@pytest.mark.parametrize(("filename", "expected"), [
    # Expected to be accepted.
    ("valid_model.safetensors", True),
    ("valid_model.sft", True),
    ("valid model.safetensors", True),          # contains a space
    ("UPPERCASE_MODEL.SAFETENSORS", True),      # upper-case name and extension
    # Expected to be rejected.
    ("model_with.multiple.dots.pt", False),     # more than one dot
    ("", False),                                # empty name
    ("../../../etc/passwd", False),             # relative path traversal
    ("/etc/passwd", False),                     # absolute POSIX path
    ("\\windows\\system32\\config\\sam", False),  # Windows-style path
    (".hidden_file.pt", False),                 # leading dot
    ("invalid<char>.ckpt", False),              # '<' and '>'
    ("invalid?.ckpt", False),                   # '?'
    ("very" * 100 + ".safetensors", False),     # 412-character name
    ("\nmodel_with_newline.pt", False),         # leading newline
    ("model_with_emoji😊.pt", False),            # non-ASCII character
])
def test_validate_filename(filename, expected):
    """Check validate_filename against the accepted and rejected examples above."""
    verdict = validate_filename(filename)
    assert verdict == expected, f"validate_filename({filename!r}) returned {verdict!r}"
| |
|