Spaces:
Build error
Build error
| import unittest | |
| from lime.wrappers.scikit_image import BaseWrapper | |
| from lime.wrappers.scikit_image import SegmentationAlgorithm | |
| from skimage.segmentation import quickshift | |
| from skimage.data import chelsea | |
| from skimage.util import img_as_float | |
| import numpy as np | |
| class TestBaseWrapper(unittest.TestCase): | |
| def test_base_wrapper(self): | |
| obj_with_params = BaseWrapper(a=10, b='message') | |
| obj_without_params = BaseWrapper() | |
| def foo_fn(): | |
| return 'bar' | |
| obj_with_fn = BaseWrapper(foo_fn) | |
| self.assertEqual(obj_with_params.target_params, {'a': 10, 'b': 'message'}) | |
| self.assertEqual(obj_without_params.target_params, {}) | |
| self.assertEqual(obj_with_fn.target_fn(), 'bar') | |
| def test__check_params(self): | |
| def bar_fn(a): | |
| return str(a) | |
| class Pipo(): | |
| def __init__(self): | |
| self.name = 'pipo' | |
| def __call__(self, message): | |
| return message | |
| pipo = Pipo() | |
| obj_with_valid_fn = BaseWrapper(bar_fn, a=10, b='message') | |
| obj_with_valid_callable_fn = BaseWrapper(pipo, c=10, d='message') | |
| obj_with_invalid_fn = BaseWrapper([1, 2, 3], fn_name='invalid') | |
| # target_fn is not a callable or function/method | |
| with self.assertRaises(AttributeError): | |
| obj_with_invalid_fn._check_params('fn_name') | |
| # parameters is not in target_fn args | |
| with self.assertRaises(ValueError): | |
| obj_with_valid_fn._check_params(['c']) | |
| obj_with_valid_callable_fn._check_params(['e']) | |
| # params is in target_fn args | |
| try: | |
| obj_with_valid_fn._check_params(['a']) | |
| obj_with_valid_callable_fn._check_params(['message']) | |
| except Exception: | |
| self.fail("_check_params() raised an unexpected exception") | |
| # params is not a dict or list | |
| with self.assertRaises(TypeError): | |
| obj_with_valid_fn._check_params(None) | |
| with self.assertRaises(TypeError): | |
| obj_with_valid_fn._check_params('param_name') | |
| def test_set_params(self): | |
| class Pipo(): | |
| def __init__(self): | |
| self.name = 'pipo' | |
| def __call__(self, message): | |
| return message | |
| pipo = Pipo() | |
| obj = BaseWrapper(pipo) | |
| # argument is set accordingly | |
| obj.set_params(message='OK') | |
| self.assertEqual(obj.target_params, {'message': 'OK'}) | |
| self.assertEqual(obj.target_fn(**obj.target_params), 'OK') | |
| # invalid argument is passed | |
| try: | |
| obj = BaseWrapper(Pipo()) | |
| obj.set_params(invalid='KO') | |
| except Exception: | |
| self.assertEqual(obj.target_params, {}) | |
| def test_filter_params(self): | |
| # right arguments are kept and wrong dismmissed | |
| def baz_fn(a, b, c=True): | |
| if c: | |
| return a + b | |
| else: | |
| return a | |
| obj_ = BaseWrapper(baz_fn, a=10, b=100, d=1000) | |
| self.assertEqual(obj_.filter_params(baz_fn), {'a': 10, 'b': 100}) | |
| # target_params is overriden using 'override' argument | |
| self.assertEqual(obj_.filter_params(baz_fn, override={'c': False}), | |
| {'a': 10, 'b': 100, 'c': False}) | |
| class TestSegmentationAlgorithm(unittest.TestCase): | |
| def test_instanciate_segmentation_algorithm(self): | |
| img = img_as_float(chelsea()[::2, ::2]) | |
| # wrapped functions provide the same result | |
| fn = SegmentationAlgorithm('quickshift', kernel_size=3, max_dist=6, | |
| ratio=0.5, random_seed=133) | |
| fn_result = fn(img) | |
| original_result = quickshift(img, kernel_size=3, max_dist=6, ratio=0.5, | |
| random_seed=133) | |
| # same segments | |
| self.assertTrue(np.array_equal(fn_result, original_result)) | |
| def test_instanciate_slic(self): | |
| pass | |
| def test_instanciate_felzenszwalb(self): | |
| pass | |
| if __name__ == '__main__': | |
| unittest.main() | |