alessandro trinca tornidor commited on
Commit ·
b680bf1
1
Parent(s): abeae27
test: fix env leak in test_resolve_model_folder, add MODEL_FOLDER tests
Browse filesWrap test_default_variant_when_no_env_vars with @patch .dict to
auto-restore os.environ on failure. Add _models_available() support
for MODEL_FOLDER env override with two new test cases.
- tests/test_app.py +24 -0
- tests/test_resolve_model_folder.py +1 -0
tests/test_app.py
CHANGED
|
@@ -3,6 +3,7 @@ import logging
|
|
| 3 |
import os
|
| 4 |
import time
|
| 5 |
import unittest
|
|
|
|
| 6 |
from unittest.mock import patch
|
| 7 |
|
| 8 |
import pytest
|
|
@@ -78,6 +79,12 @@ response_bodies_post_test = {
|
|
| 78 |
def _models_available() -> bool:
|
| 79 |
"""Check if SAM2 model files are downloaded."""
|
| 80 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
from samgis_core.prediction_api.model_registry import verify_download
|
| 82 |
|
| 83 |
variant = os.getenv("MODEL_VARIANT", "sam2.1_hiera_base_plus_uint8")
|
|
@@ -182,6 +189,23 @@ class TestFastapiApp(unittest.TestCase):
|
|
| 182 |
"Less than 2 geometries within the Shapely geometry from the geojson"
|
| 183 |
)
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
@patch.object(time, "time")
|
| 186 |
@patch.object(app, "samexporter_predict")
|
| 187 |
def test_infer_samgis_mocked_200(self, samexporter_predict_mocked, time_mocked):
|
|
|
|
| 3 |
import os
|
| 4 |
import time
|
| 5 |
import unittest
|
| 6 |
+
from pathlib import Path
|
| 7 |
from unittest.mock import patch
|
| 8 |
|
| 9 |
import pytest
|
|
|
|
| 79 |
def _models_available() -> bool:
|
| 80 |
"""Check if SAM2 model files are downloaded."""
|
| 81 |
try:
|
| 82 |
+
model_folder = os.getenv("MODEL_FOLDER")
|
| 83 |
+
if model_folder is not None:
|
| 84 |
+
folder = Path(model_folder)
|
| 85 |
+
return (folder / "encoder.onnx").exists() and (
|
| 86 |
+
folder / "decoder.onnx"
|
| 87 |
+
).exists()
|
| 88 |
from samgis_core.prediction_api.model_registry import verify_download
|
| 89 |
|
| 90 |
variant = os.getenv("MODEL_VARIANT", "sam2.1_hiera_base_plus_uint8")
|
|
|
|
| 189 |
"Less than 2 geometries within the Shapely geometry from the geojson"
|
| 190 |
)
|
| 191 |
|
| 192 |
+
@patch.dict(os.environ, {"MODEL_FOLDER": ""})
|
| 193 |
+
def test_models_available_with_model_folder_valid(self):
|
| 194 |
+
import tempfile
|
| 195 |
+
|
| 196 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 197 |
+
(Path(tmp) / "encoder.onnx").touch()
|
| 198 |
+
(Path(tmp) / "decoder.onnx").touch()
|
| 199 |
+
with patch.dict(os.environ, {"MODEL_FOLDER": tmp}):
|
| 200 |
+
self.assertTrue(_models_available())
|
| 201 |
+
|
| 202 |
+
def test_models_available_with_model_folder_empty_dir(self):
|
| 203 |
+
import tempfile
|
| 204 |
+
|
| 205 |
+
with tempfile.TemporaryDirectory() as tmp:
|
| 206 |
+
with patch.dict(os.environ, {"MODEL_FOLDER": tmp}):
|
| 207 |
+
self.assertFalse(_models_available())
|
| 208 |
+
|
| 209 |
@patch.object(time, "time")
|
| 210 |
@patch.object(app, "samexporter_predict")
|
| 211 |
def test_infer_samgis_mocked_200(self, samexporter_predict_mocked, time_mocked):
|
tests/test_resolve_model_folder.py
CHANGED
|
@@ -35,6 +35,7 @@ class TestResolveModelFolder(unittest.TestCase):
|
|
| 35 |
get_model_dir_mocked.assert_called_once_with("sam2.1_hiera_tiny_uint8")
|
| 36 |
self.assertEqual(result, Path("/mock/tiny"))
|
| 37 |
|
|
|
|
| 38 |
@patch("samgis_core.prediction_api.model_registry.get_model_dir")
|
| 39 |
def test_default_variant_when_no_env_vars(self, get_model_dir_mocked):
|
| 40 |
get_model_dir_mocked.return_value = Path("/mock/default")
|
|
|
|
| 35 |
get_model_dir_mocked.assert_called_once_with("sam2.1_hiera_tiny_uint8")
|
| 36 |
self.assertEqual(result, Path("/mock/tiny"))
|
| 37 |
|
| 38 |
+
@patch.dict(os.environ, {}, clear=False)
|
| 39 |
@patch("samgis_core.prediction_api.model_registry.get_model_dir")
|
| 40 |
def test_default_variant_when_no_env_vars(self, get_model_dir_mocked):
|
| 41 |
get_model_dir_mocked.return_value = Path("/mock/default")
|