File size: 2,245 Bytes
14967a0 b680bf1 14967a0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 | import os
import unittest
from pathlib import Path
from unittest.mock import patch
class TestResolveModelFolder(unittest.TestCase):
    """Unit tests for ``app.resolve_model_folder``.

    The precedence these tests expect:

    1. ``MODEL_FOLDER`` is set (even to an empty string): its value is
       returned as a ``Path`` and the model registry is never consulted.
    2. Otherwise, when ``MODEL_VARIANT`` is set, the folder is whatever
       ``model_registry.get_model_dir(MODEL_VARIANT)`` returns.
    3. Otherwise, ``get_model_dir`` is called with ``DEFAULT_VARIANT``.
    """

    # Registry lookup that every test replaces with a mock.
    REGISTRY_GET_MODEL_DIR = "samgis_core.prediction_api.model_registry.get_model_dir"
    # Variant the tests expect when neither MODEL_FOLDER nor MODEL_VARIANT is set.
    DEFAULT_VARIANT = "sam2.1_hiera_base_plus_uint8"

    def _resolve_with_clean_mock(self, get_model_dir_mocked):
        """Import ``app`` and return ``app.resolve_model_folder()``.

        ``app`` is imported here, while the test's patches are active. The
        mock is reset after the import so the call assertions only see calls
        made by ``resolve_model_folder`` itself. Any registry lookups that may
        happen while ``app`` is imported for the first time are excluded, for
        example when a single test is run in isolation. ``reset_mock()`` keeps
        a previously configured ``return_value``.
        """
        import app as app_module

        get_model_dir_mocked.reset_mock()
        return app_module.resolve_model_folder()

    @patch.dict(os.environ, {"MODEL_FOLDER": "/tmp/custom_models"})
    @patch(REGISTRY_GET_MODEL_DIR)
    def test_model_folder_env_override(self, get_model_dir_mocked):
        result = self._resolve_with_clean_mock(get_model_dir_mocked)
        get_model_dir_mocked.assert_not_called()
        self.assertEqual(result, Path("/tmp/custom_models"))

    @patch.dict(os.environ, {"MODEL_VARIANT": "sam2.1_hiera_tiny_uint8"})
    @patch(REGISTRY_GET_MODEL_DIR)
    def test_model_variant_env_uses_registry(self, get_model_dir_mocked):
        get_model_dir_mocked.return_value = Path("/mock/tiny")
        # Ensure MODEL_FOLDER is absent; patch.dict restores os.environ on
        # exit, so removing keys inside the test does not leak.
        os.environ.pop("MODEL_FOLDER", None)
        result = self._resolve_with_clean_mock(get_model_dir_mocked)
        get_model_dir_mocked.assert_called_once_with("sam2.1_hiera_tiny_uint8")
        self.assertEqual(result, Path("/mock/tiny"))

    @patch.dict(os.environ, {})
    @patch(REGISTRY_GET_MODEL_DIR)
    def test_default_variant_when_no_env_vars(self, get_model_dir_mocked):
        get_model_dir_mocked.return_value = Path("/mock/default")
        # The empty patch.dict only snapshots os.environ, so these removals
        # are undone when the test finishes.
        os.environ.pop("MODEL_FOLDER", None)
        os.environ.pop("MODEL_VARIANT", None)
        result = self._resolve_with_clean_mock(get_model_dir_mocked)
        get_model_dir_mocked.assert_called_once_with(self.DEFAULT_VARIANT)
        self.assertEqual(result, Path("/mock/default"))

    @patch.dict(os.environ, {"MODEL_FOLDER": ""})
    @patch(REGISTRY_GET_MODEL_DIR)
    def test_empty_string_model_folder(self, get_model_dir_mocked):
        result = self._resolve_with_clean_mock(get_model_dir_mocked)
        # An empty MODEL_FOLDER still counts as an override. Note that
        # Path("") == Path("."), i.e. the current working directory.
        # NOTE(review): confirm this is intended rather than falling back to
        # the registry default.
        get_model_dir_mocked.assert_not_called()
        self.assertEqual(result, Path(""))
if __name__ == "__main__":
    # Executed directly (not imported): run every TestCase defined in this module.
    unittest.main(module="__main__")
|