Spaces:
Build error
Build error
| """ | |
| Functions for explaining classifiers that use Image data. | |
| """ | |
| import copy | |
| from functools import partial | |
| import numpy as np | |
| import sklearn | |
| import sklearn.preprocessing | |
| from sklearn.utils import check_random_state | |
| from skimage.color import gray2rgb | |
| from tqdm.auto import tqdm | |
| from . import lime_base | |
| from .wrappers.scikit_image import SegmentationAlgorithm | |
class ImageExplanation(object):
    """Holds a LIME explanation for a single image.

    Attributes are populated by LimeImageExplainer.explain_instance():
    `intercept`, `local_exp`, `score` and `local_pred` describe the local
    linear model fitted per label.
    """

    def __init__(self, image, segments):
        """Init function.

        Args:
            image: 3d numpy array
            segments: 2d numpy array, with the output from skimage.segmentation
        """
        self.image = image
        self.segments = segments
        self.intercept = {}
        self.local_exp = {}
        self.local_pred = None
        # The explainer also assigns `score`; initialize it here so the
        # attribute always exists (symmetry with local_pred).
        self.score = None

    def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
                           num_features=5, min_weight=0.):
        """Returns (image, mask) visualizing the explanation for `label`.

        Args:
            label: label to explain
            positive_only: if True, only take superpixels that positively contribute to
                the prediction of the label.
            negative_only: if True, only take superpixels that negatively contribute to
                the prediction of the label. If false, and so is positive_only, then both
                negatively and positively contributing superpixels will be taken.
                Both can't be True at the same time.
            hide_rest: if True, make the non-explanation part of the return
                image gray
            num_features: number of superpixels to include in explanation
            min_weight: minimum weight of the superpixels to include in explanation

        Returns:
            (image, mask), where image is a 3d numpy array and mask is a 2d
            numpy array that can be used with
            skimage.segmentation.mark_boundaries

        Raises:
            KeyError: if `label` was not explained.
            ValueError: if positive_only and negative_only are both True.
        """
        if label not in self.local_exp:
            raise KeyError('Label not in explanation')
        # Fixed: original used bitwise `&`; `and` is the correct boolean
        # operator here and short-circuits.
        if positive_only and negative_only:
            raise ValueError("Positive_only and negative_only cannot be true at the same time.")
        segments = self.segments
        image = self.image
        exp = self.local_exp[label]
        mask = np.zeros(segments.shape, segments.dtype)
        if hide_rest:
            temp = np.zeros(self.image.shape)
        else:
            temp = self.image.copy()
        if positive_only:
            fs = [x[0] for x in exp
                  if x[1] > 0 and x[1] > min_weight][:num_features]
        if negative_only:
            fs = [x[0] for x in exp
                  if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
        if positive_only or negative_only:
            # Show only the selected superpixels at their original colors.
            for f in fs:
                temp[segments == f] = image[segments == f].copy()
                mask[segments == f] = 1
            return temp, mask
        else:
            # Both signs: tint channel 1 for positive weights, channel 0 for
            # negative weights; mask encodes the sign (+1 / -1).
            for f, w in exp[:num_features]:
                if np.abs(w) < min_weight:
                    continue
                c = 0 if w < 0 else 1
                mask[segments == f] = -1 if w < 0 else 1
                temp[segments == f] = image[segments == f].copy()
                temp[segments == f, c] = np.max(image)
            return temp, mask
class LimeImageExplainer(object):
    """Explains predictions on Image (i.e. matrix) data.
    For numerical features, perturb them by sampling from a Normal(0,1) and
    doing the inverse operation of mean-centering and scaling, according to the
    means and stds in the training data. For categorical features, perturb by
    sampling according to the training distribution, and making a binary
    feature that is 1 when the value is the same as the instance being
    explained."""

    def __init__(self, kernel_width=.25, kernel=None, verbose=False,
                 feature_selection='auto', random_state=None):
        """Init function.

        Args:
            kernel_width: kernel width for the exponential kernel.
                If None, defaults to sqrt(number of columns) * 0.75.
            kernel: similarity kernel that takes euclidean distances and kernel
                width as input and outputs weights in (0,1). If None, defaults to
                an exponential kernel.
            verbose: if true, print local prediction values from linear model
            feature_selection: feature selection method. can be
                'forward_selection', 'lasso_path', 'none' or 'auto'.
                See function 'explain_instance_with_data' in lime_base.py for
                details on what each of the options does.
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.
        """
        kernel_width = float(kernel_width)

        if kernel is None:
            # Default similarity kernel: exponential decay on distance,
            # producing weights in (0, 1].
            def kernel(d, kernel_width):
                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

        kernel_fn = partial(kernel, kernel_width=kernel_width)

        self.random_state = check_random_state(random_state)
        self.feature_selection = feature_selection
        self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state)

    ### Custom function to acquire segmentation only, same as in the explain_instance() function
    def acquireSegmOnly(self, img):
        """Returns the quickshift segmentation of `img` without explaining it.

        Uses the same quickshift parameters (kernel_size=4, max_dist=200,
        ratio=0.2) as the default path in explain_instance(), so the result
        matches what the explainer would compute internally.
        """
        random_seed = self.random_state.randint(0, high=1000)
        segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
                                                max_dist=200, ratio=0.2,
                                                random_seed=random_seed)
        segments = segmentation_fn(img)
        return segments

    def explain_instance(self, image, inputImg, classifier_fn, labels=(1,),
                         hide_color=None,
                         top_labels=5, num_features=100000, num_samples=1000,
                         batch_size=10,
                         segmentation_fn=None,
                         distance_metric='cosine',
                         model_regressor=None,
                         random_seed=None,
                         squaredSegm=None,
                         loadedSegmData=None):
        """Generates explanations for a prediction.

        First, we generate neighborhood data by randomly perturbing features
        from the instance (see __data_inverse). We then learn locally weighted
        linear models on this neighborhood data to explain each of the classes
        in an interpretable way (see lime_base.py).

        Args:
            image: 3 dimension RGB image. If this is only two dimensional,
                we will assume it's a grayscale image and call gray2rgb.
            inputImg: custom addition — forwarded unchanged to data_labels(),
                where it is what classifier_fn is actually called on.
                NOTE(review): its exact type/shape is determined by the
                caller's classifier_fn; not verifiable from this file.
            classifier_fn: classifier prediction probability function, which
                takes a numpy array and outputs prediction probabilities. For
                ScikitClassifiers , this is classifier.predict_proba.
                NOTE(review): data_labels() calls .cpu().detach() on the
                result, so here it is presumably a torch tensor — confirm.
            labels: iterable with labels to be explained.
            hide_color: TODO
            top_labels: if not None, ignore labels and produce explanations for
                the K labels with highest prediction probabilities, where K is
                this parameter.
            num_features: maximum number of features present in explanation
            num_samples: size of the neighborhood to learn the linear model
            batch_size: TODO
            distance_metric: the distance metric to use for weights.
            model_regressor: sklearn regressor to use in explanation. Defaults
                to Ridge regression in LimeBase. Must have model_regressor.coef_
                and 'sample_weight' as a parameter to model_regressor.fit()
            segmentation_fn: SegmentationAlgorithm, wrapped skimage
                segmentation function
            random_seed: integer used as random seed for the segmentation
                algorithm. If None, a random integer, between 0 and 1000,
                will be generated using the internal random number generator.
            squaredSegm: integer or None (default). Custom segmentation switch:
                4 -> split the image into 4 equal-width vertical strips;
                -2 -> use `loadedSegmData` as the segmentation;
                anything else -> run `segmentation_fn` (quickshift by default).
            loadedSegmData: 2d array used as the segmentation when
                squaredSegm == -2.

        Returns:
            An ImageExplanation object (see lime_image.py) with the corresponding
            explanations.
        """
        if len(image.shape) == 2:
            image = gray2rgb(image)
        if random_seed is None:
            random_seed = self.random_state.randint(0, high=1000)

        if squaredSegm == 4:
            # Custom path: 4 equal-width vertical strips labeled 0..3, no
            # segmentation algorithm involved.
            segments = np.zeros((image.shape[0], image.shape[1]), dtype=np.int64)
            imgW = image.shape[1]
            halfW1 = 1*imgW//4
            halfW2 = 2*imgW//4
            halfW3 = 3*imgW//4
            segments[:,0:halfW1] = 0
            segments[:,halfW1:halfW2] = 1
            segments[:,halfW2:halfW3] = 2
            segments[:,halfW3:imgW] = 3
        elif squaredSegm == -2: ### Use to load custom resized segm data
            segments = loadedSegmData
        else:
            if segmentation_fn is None:
                segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
                                                        max_dist=200, ratio=0.2,
                                                        random_seed=random_seed)
            try:
                segments = segmentation_fn(image)
            except ValueError as e:
                raise e

        # fudged_image is what a superpixel is replaced with when "turned off":
        # either its per-channel mean color, or a constant hide_color.
        fudged_image = image.copy()
        if hide_color is None:
            for x in np.unique(segments):
                fudged_image[segments == x] = (
                    np.mean(image[segments == x][:, 0]),
                    np.mean(image[segments == x][:, 1]),
                    np.mean(image[segments == x][:, 2]))
        else:
            fudged_image[:] = hide_color

        top = labels

        data, labels = self.data_labels(image, inputImg, fudged_image, segments,
                                        classifier_fn, num_samples,
                                        batch_size=batch_size)

        # Distance of each perturbed sample to the unperturbed instance
        # (row 0), used by the kernel to weight the local linear model.
        distances = sklearn.metrics.pairwise_distances(
            data,
            data[0].reshape(1, -1),
            metric=distance_metric
        ).ravel()

        ret_exp = ImageExplanation(image, segments)
        if top_labels:
            # Labels sorted by descending predicted probability of the
            # unperturbed instance.
            top = np.argsort(labels[0])[-top_labels:]
            ret_exp.top_labels = list(top)
            ret_exp.top_labels.reverse()
        for label in top:
            (ret_exp.intercept[label],
             ret_exp.local_exp[label],
             ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data(
                data, labels, distances, label, num_features,
                model_regressor=model_regressor,
                feature_selection=self.feature_selection)
        return ret_exp

    def data_labels(self,
                    image,
                    inputImg,
                    fudged_image,
                    segments,
                    classifier_fn,
                    num_samples,
                    batch_size=10):
        """Generates images and predictions in the neighborhood of this image.

        Args:
            image: 3d numpy array, the image
            inputImg: custom addition — the object actually passed to
                classifier_fn (see NOTE below).
            fudged_image: 3d numpy array, image to replace original image when
                superpixel is turned off
            segments: segmentation of the image
            classifier_fn: function that takes a list of images and returns a
                matrix of prediction probabilities
            num_samples: size of the neighborhood to learn the linear model
            batch_size: classifier_fn will be called on batches of this size.

        Returns:
            A tuple (data, labels), where:
                data: dense num_samples * num_superpixels
                labels: prediction probabilities matrix

        NOTE(review): unlike upstream LIME, classifier_fn is called on
        `inputImg` rather than on the perturbed batch `imgs`; the perturbed
        images are built below but never passed to the classifier here. TODO
        confirm this is intended (e.g. the classifier wrapper may consume the
        perturbations through another channel) — otherwise every batch yields
        identical predictions.
        """
        n_features = np.unique(segments).shape[0]
        # Binary on/off vector per sample; row 0 is the unperturbed instance
        # (all superpixels on).
        data = self.random_state.randint(0, 2, num_samples * n_features)\
            .reshape((num_samples, n_features))
        labels = []
        data[0, :] = 1
        imgs = []
        # print("data new shape: ", data.shape)
        # assert(False)
        # for row in tqdm(data):
        for row in data:
            # Replace every "off" superpixel with the fudged (hidden) pixels.
            temp = copy.deepcopy(image)
            zeros = np.where(row == 0)[0]
            mask = np.zeros(segments.shape).astype(bool)
            for z in zeros:
                mask[segments == z] = True
            temp[mask] = fudged_image[mask]
            imgs.append(temp)
            if len(imgs) == batch_size:
                preds = classifier_fn(inputImg)
                # Presumably a torch tensor — moved to CPU and detached from
                # the autograd graph before conversion. TODO confirm.
                preds = preds.cpu().detach().numpy()
                labels.extend(preds)
                imgs = []
        if len(imgs) > 0:
            # Flush the final partial batch.
            preds = classifier_fn(inputImg)
            preds = preds.cpu().detach().numpy()
            labels.extend(preds)
        return data, np.array(labels)