Spaces:
Build error
Build error
| import types | |
| from lime.utils.generic_utils import has_arg | |
| from skimage.segmentation import felzenszwalb, slic, quickshift | |
class BaseWrapper(object):
    """Base class for LIME Scikit-Image wrapper.

    Stores a target callable together with the keyword parameters that
    should be passed to it, and offers helpers to validate and filter
    those parameters against a function signature.

    Args:
        target_fn: callable function or class instance
        target_params: dict, parameters to pass to the target_fn

    'target_params' takes parameters required to instantiate the
    desired Scikit-Image class/model.
    """

    def __init__(self, target_fn=None, **target_params):
        self.target_fn = target_fn
        self.target_params = target_params

    def _check_params(self, parameters):
        """Checks for mistakes in 'parameters'.

        Args:
            parameters: dict (or other iterable of names), parameters to
                be checked

        Raises:
            ValueError: if any parameter is not a valid argument for the
                target function (as reported by `has_arg`)
            TypeError: if no target function is set and this object is
                not itself callable, if `parameters` is a string, or if
                `parameters` is not iterable
        """
        # Resolve the callable whose signature the names are checked against.
        if self.target_fn is None:
            if not callable(self):
                raise TypeError('invalid argument: tested object is not '
                                'callable, please provide a valid target_fn')
            checked_fn = self.__call__
        elif isinstance(self.target_fn,
                        (types.FunctionType, types.MethodType)):
            checked_fn = self.target_fn
        else:
            # Neither a plain function nor a bound method: inspect its
            # __call__ attribute instead.
            checked_fn = self.target_fn.__call__

        # A string is iterable, but iterating it would check single
        # characters as parameter names, so it is rejected explicitly.
        if isinstance(parameters, str):
            raise TypeError('invalid argument: list or dictionnary expected')

        for param_name in parameters:
            if not has_arg(checked_fn, param_name):
                raise ValueError(
                    '{} is not a valid parameter'.format(param_name))

    def set_params(self, **params):
        """Sets the parameters of this estimator.

        The given parameters replace (rather than extend) any previously
        stored `target_params`.

        Args:
            **params: Dictionary of parameter names mapped to their values.

        Raises:
            ValueError: if any parameter is not a valid argument
                for the target function
        """
        self._check_params(params)
        self.target_params = params

    def filter_params(self, fn, override=None):
        """Filters `target_params` and returns those in `fn`'s arguments.

        Args:
            fn: arbitrary function
            override: dict, values to override target_params; these
                entries are added to the result without being checked
                against `fn`

        Returns:
            result: dict, dictionary containing variables
                in both target_params and fn's arguments, updated with
                `override`.
        """
        override = override or {}
        result = {
            name: value
            for name, value in self.target_params.items()
            if has_arg(fn, name)
        }
        result.update(override)
        return result
class SegmentationAlgorithm(BaseWrapper):
    """Define the image segmentation function based on Scikit-Image
    implementation and a set of provided parameters.

    Args:
        algo_type: string, segmentation algorithm among the following:
            'quickshift', 'slic', 'felzenszwalb'
        target_params: dict, algorithm parameters (valid model parameters
            as defined in Scikit-Image documentation). Parameters that the
            selected function does not accept are dropped.

    Raises:
        ValueError: if `algo_type` is not one of the supported names
    """

    def __init__(self, algo_type, **target_params):
        # Built per call so the functions are looked up as plain dict
        # values, never as class attributes.
        algorithms = {
            'quickshift': quickshift,
            'felzenszwalb': felzenszwalb,
            'slic': slic,
        }
        self.algo_type = algo_type
        try:
            segmentation_fn = algorithms[algo_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable values passed as algo_type.
            segmentation_fn = None
        if segmentation_fn is None:
            # Fail here rather than leave an instance without target_fn
            # that would only break later, inside __call__.
            raise ValueError(
                "unknown algo_type {!r}; expected one of 'quickshift', "
                "'slic', 'felzenszwalb'".format(algo_type))

        BaseWrapper.__init__(self, segmentation_fn, **target_params)
        # Keep only the parameters the chosen function accepts, then
        # store them through the validating setter.
        self.set_params(**self.filter_params(segmentation_fn))

    def __call__(self, *args):
        """Run the selected segmentation function.

        Args:
            *args: the first positional argument is forwarded to the
                segmentation function; any further positional arguments
                are ignored.

        Returns:
            Whatever the wrapped segmentation function returns when called
            with `args[0]` and the stored `target_params`.
        """
        return self.target_fn(args[0], **self.target_params)