alessandro trinca tornidor commited on
Commit
b680bf1
·
1 Parent(s): abeae27

test: fix env leak in test_resolve_model_folder, add MODEL_FOLDER tests

Browse files

Wrap test_default_variant_when_no_env_vars with @patch .dict to
auto-restore os.environ on failure. Add _models_available() support
for MODEL_FOLDER env override with two new test cases.

tests/test_app.py CHANGED
@@ -3,6 +3,7 @@ import logging
3
  import os
4
  import time
5
  import unittest
 
6
  from unittest.mock import patch
7
 
8
  import pytest
@@ -78,6 +79,12 @@ response_bodies_post_test = {
78
  def _models_available() -> bool:
79
  """Check if SAM2 model files are downloaded."""
80
  try:
 
 
 
 
 
 
81
  from samgis_core.prediction_api.model_registry import verify_download
82
 
83
  variant = os.getenv("MODEL_VARIANT", "sam2.1_hiera_base_plus_uint8")
@@ -182,6 +189,23 @@ class TestFastapiApp(unittest.TestCase):
182
  "Less than 2 geometries within the Shapely geometry from the geojson"
183
  )
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  @patch.object(time, "time")
186
  @patch.object(app, "samexporter_predict")
187
  def test_infer_samgis_mocked_200(self, samexporter_predict_mocked, time_mocked):
 
3
  import os
4
  import time
5
  import unittest
6
+ from pathlib import Path
7
  from unittest.mock import patch
8
 
9
  import pytest
 
79
  def _models_available() -> bool:
80
  """Check if SAM2 model files are downloaded."""
81
  try:
82
+ model_folder = os.getenv("MODEL_FOLDER")
83
+ if model_folder is not None:
84
+ folder = Path(model_folder)
85
+ return (folder / "encoder.onnx").exists() and (
86
+ folder / "decoder.onnx"
87
+ ).exists()
88
  from samgis_core.prediction_api.model_registry import verify_download
89
 
90
  variant = os.getenv("MODEL_VARIANT", "sam2.1_hiera_base_plus_uint8")
 
189
  "Less than 2 geometries within the Shapely geometry from the geojson"
190
  )
191
 
192
+ @patch.dict(os.environ, {"MODEL_FOLDER": ""})
193
+ def test_models_available_with_model_folder_valid(self):
194
+ import tempfile
195
+
196
+ with tempfile.TemporaryDirectory() as tmp:
197
+ (Path(tmp) / "encoder.onnx").touch()
198
+ (Path(tmp) / "decoder.onnx").touch()
199
+ with patch.dict(os.environ, {"MODEL_FOLDER": tmp}):
200
+ self.assertTrue(_models_available())
201
+
202
+ def test_models_available_with_model_folder_empty_dir(self):
203
+ import tempfile
204
+
205
+ with tempfile.TemporaryDirectory() as tmp:
206
+ with patch.dict(os.environ, {"MODEL_FOLDER": tmp}):
207
+ self.assertFalse(_models_available())
208
+
209
  @patch.object(time, "time")
210
  @patch.object(app, "samexporter_predict")
211
  def test_infer_samgis_mocked_200(self, samexporter_predict_mocked, time_mocked):
tests/test_resolve_model_folder.py CHANGED
@@ -35,6 +35,7 @@ class TestResolveModelFolder(unittest.TestCase):
35
  get_model_dir_mocked.assert_called_once_with("sam2.1_hiera_tiny_uint8")
36
  self.assertEqual(result, Path("/mock/tiny"))
37
 
 
38
  @patch("samgis_core.prediction_api.model_registry.get_model_dir")
39
  def test_default_variant_when_no_env_vars(self, get_model_dir_mocked):
40
  get_model_dir_mocked.return_value = Path("/mock/default")
 
35
  get_model_dir_mocked.assert_called_once_with("sam2.1_hiera_tiny_uint8")
36
  self.assertEqual(result, Path("/mock/tiny"))
37
 
38
+ @patch.dict(os.environ, {}, clear=False)
39
  @patch("samgis_core.prediction_api.model_registry.get_model_dir")
40
  def test_default_variant_when_no_env_vars(self, get_model_dir_mocked):
41
  get_model_dir_mocked.return_value = Path("/mock/default")