Spaces:
Build error
Build error
| """ | |
| Explanation class, with visualization functions. | |
| """ | |
| from io import open | |
| import os | |
| import os.path | |
| import json | |
| import string | |
| import numpy as np | |
| from .exceptions import LimeError | |
| from sklearn.utils import check_random_state | |
def id_generator(size=15, random_state=None):
    """Helper function to generate random div ids. This is useful for embedding
    HTML into ipython notebooks.

    Args:
        size: number of characters in the generated id.
        random_state: a numpy.random.RandomState, an integer seed, or None.
            None and integer seeds are converted into a
            numpy.random.RandomState (None draws a fresh seed), so the
            function also works when called with its defaults.

    Returns:
        A string of ``size`` characters drawn from A-Z and 0-9.
    """
    # The default (None) has no ``choice`` method; normalise None / int seeds
    # into a RandomState so every accepted input can be sampled from.
    if random_state is None or isinstance(random_state, (int, np.integer)):
        random_state = np.random.RandomState(random_state)
    chars = list(string.ascii_uppercase + string.digits)
    return ''.join(random_state.choice(chars, size, replace=True))
class DomainMapper(object):
    """Base class that translates feature ids into domain-specific output.

    Each data domain (text, tables, images, ...) is meant to get its own
    subclass. This keeps the generic Explanation class free of any
    domain-specific naming or rendering logic.
    """

    def __init__(self):
        """The base mapper holds no state, so there is nothing to set up."""

    def map_exp_ids(self, exp, **kwargs):
        """Translate feature ids in an explanation into concrete names.

        The base implementation is the identity mapping: the very same list
        object it receives is handed back untouched. Subclasses override this
        to produce human-readable names.

        Args:
            exp: list of (id, weight) tuples.
            kwargs: optional keyword arguments (ignored here).

        Returns:
            list of (name, weight) tuples; here simply ``exp`` itself.
        """
        mapped = exp
        return mapped

    def visualize_instance_html(self,
                                exp,
                                label,
                                div_name,
                                exp_object_name,
                                **kwargs):
        """Build the JS snippet that renders the explained instance.

        The base implementation renders nothing and yields an empty string.
        Subclasses override this with domain-specific drawing code.

        Args:
            exp: list of (id, weight) tuples.
            label: label id (integer).
            div_name: name of the JS div object to render into.
            exp_object_name: name of the JS explanation object.
            kwargs: optional keyword arguments (ignored here).

        Returns:
            JS code as a string (empty for the base class).
        """
        return ""
class Explanation(object):
    """Object returned by explainers.

    Holds the per-label local explanations (``local_exp``) together with the
    intercepts, scores and predictions filled in by an explainer, and can
    render them as a list, a matplotlib figure, or a standalone HTML page.
    """

    def __init__(self,
                 domain_mapper,
                 mode='classification',
                 class_names=None,
                 random_state=None):
        """
        Initializer.

        Args:
            domain_mapper: must inherit from DomainMapper class
            mode: "classification" or "regression"
            class_names: list of class names (only used for classification)
            random_state: an integer or numpy.RandomState that will be used to
                generate random numbers. If None, the random state will be
                initialized using the internal numpy seed.

        Raises:
            LimeError: if ``mode`` is neither "classification" nor
                "regression".
        """
        self.random_state = random_state
        self.mode = mode
        self.domain_mapper = domain_mapper
        # label -> list of (feature_id, weight); populated by the explainer.
        self.local_exp = {}
        self.intercept = {}
        self.score = None
        self.local_pred = None
        if mode == 'classification':
            self.class_names = class_names
            self.top_labels = None
            self.predict_proba = None
        elif mode == 'regression':
            # Regression explanations live under one fixed key (dummy_label);
            # these two names are what the HTML renderer shows as classes.
            self.class_names = ['negative', 'positive']
            self.predicted_value = None
            self.min_value = 0.0
            self.max_value = 1.0
            self.dummy_label = 1
        else:
            raise LimeError('Invalid explanation mode "{}". '
                            'Should be either "classification" '
                            'or "regression".'.format(mode))

    def available_labels(self):
        """
        Returns the list of classification labels for which we have any explanations.

        Uses ``top_labels`` when it is set (truthy), otherwise the keys of
        ``local_exp``.

        Raises:
            NotImplementedError: for regression explanations.
        """
        # An explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an AttributeError on
        # ``self.top_labels`` (never set in regression mode).
        if self.mode != "classification":
            raise NotImplementedError('Not supported for regression explanations.')
        ans = self.top_labels if self.top_labels else self.local_exp.keys()
        return list(ans)

    def as_list(self, label=1, **kwargs):
        """Returns the explanation as a list.

        Args:
            label: desired label. If you ask for a label for which an
                explanation wasn't computed, will throw an exception
                (KeyError from ``local_exp``).
                Will be ignored for regression explanations.
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            list of tuples (representation, weight), where representation is
            given by domain_mapper. Weight is a float.
        """
        label_to_use = label if self.mode == "classification" else self.dummy_label
        ans = self.domain_mapper.map_exp_ids(self.local_exp[label_to_use], **kwargs)
        return [(x[0], float(x[1])) for x in ans]

    def as_map(self):
        """Returns the map of explanations.

        Returns:
            Map from label to list of tuples (feature_id, weight). This is the
            internal dict itself, not a copy.
        """
        return self.local_exp

    def as_pyplot_figure(self, label=1, **kwargs):
        """Returns the explanation as a pyplot figure.

        Will throw an error if you don't have matplotlib installed

        Args:
            label: desired label. If you ask for a label for which an
                explanation wasn't computed, will throw an exception.
                Will be ignored for regression explanations.
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            pyplot figure (barchart).
        """
        import matplotlib.pyplot as plt
        exp = self.as_list(label=label, **kwargs)
        fig = plt.figure()
        vals = [x[1] for x in exp]
        names = [x[0] for x in exp]
        # Reverse so the first feature of the list ends up at the top of the
        # horizontal bar chart.
        vals.reverse()
        names.reverse()
        colors = ['green' if x > 0 else 'red' for x in vals]
        pos = np.arange(len(exp)) + .5
        plt.barh(pos, vals, align='center', color=colors)
        plt.yticks(pos, names)
        if self.mode == "classification":
            title = 'Local explanation for class %s' % self.class_names[label]
        else:
            title = 'Local explanation'
        plt.title(title)
        return fig

    def show_in_notebook(self,
                         labels=None,
                         predict_proba=True,
                         show_predicted_value=True,
                         **kwargs):
        """Shows html explanation in ipython notebook.

        See as_html() for parameters.
        This will throw an error if you don't have IPython installed"""
        # IPython.display is the public location of these helpers; importing
        # ``display`` from IPython.core.display is deprecated.
        from IPython.display import display, HTML
        display(HTML(self.as_html(labels=labels,
                                  predict_proba=predict_proba,
                                  show_predicted_value=show_predicted_value,
                                  **kwargs)))

    def save_to_file(self,
                     file_path,
                     labels=None,
                     predict_proba=True,
                     show_predicted_value=True,
                     **kwargs):
        """Saves html explanation to file.

        Params:
            file_path: file to save explanations to (written as UTF-8,
                overwritten if it exists)

        See as_html() for additional parameters.
        """
        # Build the page first, then write it inside ``with`` so the handle
        # is always closed, even if writing fails.
        html = self.as_html(labels=labels,
                            predict_proba=predict_proba,
                            show_predicted_value=show_predicted_value,
                            **kwargs)
        with open(file_path, 'w', encoding='utf8') as file_:
            file_.write(html)

    def as_html(self,
                labels=None,
                predict_proba=True,
                show_predicted_value=True,
                **kwargs):
        """Returns the explanation as an html page.

        Args:
            labels: desired labels to show explanations for (as barcharts).
                If you ask for a label for which an explanation wasn't
                computed, will throw an exception. If None, will show
                explanations for all available labels. (only used for classification)
            predict_proba: if true, add barchart with prediction probabilities
                for the top classes. (only used for classification)
            show_predicted_value: if true, add barchart with expected value
                (only used for regression)
            kwargs: keyword arguments, passed to domain_mapper

        Returns:
            code for an html page, including javascript includes.
        """
        def jsonize(x):
            return json.dumps(x, ensure_ascii=False)

        if labels is None and self.mode == "classification":
            labels = self.available_labels()

        # The JS bundle sits next to this module and is inlined into the page.
        # Read it inside ``with`` so the file handle is closed promptly.
        this_dir, _ = os.path.split(__file__)
        with open(os.path.join(this_dir, 'bundle.js'),
                  encoding="utf8") as bundle_file:
            bundle = bundle_file.read()

        out = u'''<html>
        <meta http-equiv="content-type" content="text/html; charset=UTF8">
        <head><script>%s </script></head><body>''' % bundle
        # Random div id so several explanations can coexist in one notebook.
        random_id = id_generator(size=15, random_state=check_random_state(self.random_state))
        out += u'''
        <div class="lime top_div" id="top_div%s"></div>
        ''' % random_id

        predict_proba_js = ''
        if self.mode == "classification" and predict_proba:
            predict_proba_js = u'''
            var pp_div = top_div.append('div')
                                .classed('lime predict_proba', true);
            var pp_svg = pp_div.append('svg').style('width', '100%%');
            var pp = new lime.PredictProba(pp_svg, %s, %s);
            ''' % (jsonize([str(x) for x in self.class_names]),
                   jsonize(list(self.predict_proba.astype(float))))

        predict_value_js = ''
        if self.mode == "regression" and show_predicted_value:
            # JS signature: (svg, predicted_value, min_value, max_value)
            predict_value_js = u'''
            var pp_div = top_div.append('div')
                                .classed('lime predicted_value', true);
            var pp_svg = pp_div.append('svg').style('width', '100%%');
            var pp = new lime.PredictedValue(pp_svg, %s, %s, %s);
            ''' % (jsonize(float(self.predicted_value)),
                   jsonize(float(self.min_value)),
                   jsonize(float(self.max_value)))

        exp_js = '''var exp_div;
            var exp = new lime.Explanation(%s);
        ''' % (jsonize([str(x) for x in self.class_names]))

        if self.mode == "classification":
            for label in labels:
                exp_json = jsonize(self.as_list(label))
                exp_js += u'''
                exp_div = top_div.append('div').classed('lime explanation', true);
                exp.show(%s, %d, exp_div);
                ''' % (exp_json, label)
        else:
            exp_json = jsonize(self.as_list())
            exp_js += u'''
            exp_div = top_div.append('div').classed('lime explanation', true);
            exp.show(%s, %s, exp_div);
            ''' % (exp_json, self.dummy_label)

        raw_js = '''var raw_div = top_div.append('div');'''

        # The raw instance view is drawn for the first requested label
        # (classification) or for the fixed dummy label (regression).
        if self.mode == "classification":
            raw_label = labels[0]
        else:
            raw_label = self.dummy_label
        raw_js += self.domain_mapper.visualize_instance_html(
            self.local_exp[raw_label],
            raw_label,
            'raw_div',
            'exp',
            **kwargs)

        out += u'''
        <script>
        var top_div = d3.select('#top_div%s').classed('lime top_div', true);
        %s
        %s
        %s
        %s
        </script>
        ''' % (random_id, predict_proba_js, predict_value_js, exp_js, raw_js)
        out += u'</body></html>'

        return out