# Tests for lime.lime_text: LimeTextExplainer end-to-end on 20newsgroups,
# plus unit tests for the IndexedCharacters and IndexedString tokenizations.
| import re | |
| import unittest | |
| import sklearn # noqa | |
| from sklearn.datasets import fetch_20newsgroups | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics import f1_score | |
| from sklearn.naive_bayes import MultinomialNB | |
| from sklearn.pipeline import make_pipeline | |
| import numpy as np | |
| from lime.lime_text import LimeTextExplainer | |
| from lime.lime_text import IndexedCharacters, IndexedString | |
class TestLimeText(unittest.TestCase):
    """Tests for LimeTextExplainer and the Indexed* string helpers."""

    @staticmethod
    def _train_newsgroups_classifier():
        """Fit a TF-IDF + MultinomialNB pipeline on two 20newsgroups
        categories.

        Returns:
            (pipeline, test_data, class_names): a fitted sklearn pipeline
            exposing predict_proba, the raw test documents, and the
            human-readable class names used by the explainer.
        """
        categories = ['alt.atheism', 'soc.religion.christian']
        newsgroups_train = fetch_20newsgroups(subset='train',
                                              categories=categories)
        newsgroups_test = fetch_20newsgroups(subset='test',
                                             categories=categories)
        # lowercase=False so explanation tokens match the original text.
        vectorizer = TfidfVectorizer(lowercase=False)
        train_vectors = vectorizer.fit_transform(newsgroups_train.data)
        nb = MultinomialNB(alpha=.01)
        nb.fit(train_vectors, newsgroups_train.target)
        # NOTE(review): the original also computed predictions and an
        # f1_score on the test split, but the results were discarded;
        # that dead work is removed here.
        pipeline = make_pipeline(vectorizer, nb)
        return pipeline, newsgroups_test.data, ['atheism', 'christian']

    def test_lime_text_explainer_good_regressor(self):
        """Explaining a document yields exactly num_features features."""
        c, test_data, class_names = self._train_newsgroups_classifier()
        explainer = LimeTextExplainer(class_names=class_names)
        exp = explainer.explain_instance(test_data[83],
                                         c.predict_proba, num_features=6)
        self.assertIsNotNone(exp)
        self.assertEqual(6, len(exp.as_list()))

    def test_lime_text_tabular_equal_random_state(self):
        """Same random_state => identical explanations."""
        c, test_data, class_names = self._train_newsgroups_classifier()
        explainer = LimeTextExplainer(class_names=class_names,
                                      random_state=10)
        exp_1 = explainer.explain_instance(test_data[83],
                                           c.predict_proba, num_features=6)
        explainer = LimeTextExplainer(class_names=class_names,
                                      random_state=10)
        exp_2 = explainer.explain_instance(test_data[83],
                                           c.predict_proba, num_features=6)
        self.assertTrue(exp_1.as_map() == exp_2.as_map())

    def test_lime_text_tabular_not_equal_random_state(self):
        """Different random_state => different explanations."""
        c, test_data, class_names = self._train_newsgroups_classifier()
        explainer = LimeTextExplainer(class_names=class_names,
                                      random_state=10)
        exp_1 = explainer.explain_instance(test_data[83],
                                           c.predict_proba, num_features=6)
        explainer = LimeTextExplainer(class_names=class_names,
                                      random_state=20)
        exp_2 = explainer.explain_instance(test_data[83],
                                           c.predict_proba, num_features=6)
        self.assertFalse(exp_1.as_map() == exp_2.as_map())

    def test_indexed_characters_bow(self):
        """Bag-of-words mode: repeated characters share one vocab entry,
        with all their positions recorded."""
        s = 'Please, take your time'
        inverse_vocab = ['P', 'l', 'e', 'a', 's', ',', ' ', 't', 'k',
                         'y', 'o', 'u', 'r', 'i', 'm']
        positions = [[0], [1], [2, 5, 11, 21], [3, 9],
                     [4], [6], [7, 12, 17], [8, 18], [10],
                     [13], [14], [15], [16], [19], [20]]
        ic = IndexedCharacters(s)
        self.assertTrue(np.array_equal(ic.as_np, np.array(list(s))))
        self.assertTrue(np.array_equal(ic.string_start, np.arange(len(s))))
        self.assertTrue(ic.inverse_vocab == inverse_vocab)
        self.assertTrue(ic.positions == positions)

    def test_indexed_characters_not_bow(self):
        """Non-BOW mode: every character is its own vocab entry."""
        s = 'Please, take your time'
        ic = IndexedCharacters(s, bow=False)
        self.assertTrue(np.array_equal(ic.as_np, np.array(list(s))))
        self.assertTrue(np.array_equal(ic.string_start, np.arange(len(s))))
        self.assertTrue(ic.inverse_vocab == list(s))
        self.assertTrue(np.array_equal(ic.positions, np.arange(len(s))))

    def test_indexed_string_regex(self):
        """Default (regex) tokenization: non-word runs separate tokens;
        duplicate words share one vocab entry."""
        s = 'Please, take your time. Please'
        tokenized_string = np.array(
            ['Please', ', ', 'take', ' ', 'your', ' ', 'time', '. ',
             'Please'])
        inverse_vocab = ['Please', 'take', 'your', 'time']
        start_positions = [0, 6, 8, 12, 13, 17, 18, 22, 24]
        positions = [[0, 8], [2], [4], [6]]
        indexed_string = IndexedString(s)
        self.assertTrue(np.array_equal(indexed_string.as_np,
                                       tokenized_string))
        self.assertTrue(np.array_equal(indexed_string.string_start,
                                       start_positions))
        self.assertTrue(indexed_string.inverse_vocab == inverse_vocab)
        self.assertTrue(np.array_equal(indexed_string.positions, positions))

    def test_indexed_string_callable(self):
        """A custom callable tokenizer is honored."""
        s = 'aabbccddaa'

        def tokenizer(string):
            # Split into consecutive 2-character tokens.
            return [string[i] + string[i + 1]
                    for i in range(0, len(string) - 1, 2)]

        tokenized_string = np.array(['aa', 'bb', 'cc', 'dd', 'aa'])
        inverse_vocab = ['aa', 'bb', 'cc', 'dd']
        start_positions = [0, 2, 4, 6, 8]
        positions = [[0, 4], [1], [2], [3]]
        indexed_string = IndexedString(s, tokenizer)
        self.assertTrue(np.array_equal(indexed_string.as_np,
                                       tokenized_string))
        self.assertTrue(np.array_equal(indexed_string.string_start,
                                       start_positions))
        self.assertTrue(indexed_string.inverse_vocab == inverse_vocab)
        self.assertTrue(np.array_equal(indexed_string.positions, positions))

    def test_indexed_string_inverse_removing_tokenizer(self):
        """inverse_removing([]) must reconstruct the exact original string
        when using a custom tokenizer."""
        s = 'This is a good movie. This, it is a great movie.'

        def tokenizer(string):
            return re.split(r'(?:\W+)|$', string)

        indexed_string = IndexedString(s, tokenizer)
        self.assertEqual(s, indexed_string.inverse_removing([]))

    def test_indexed_string_inverse_removing_regex(self):
        """inverse_removing([]) must reconstruct the exact original string
        with the default regex tokenization."""
        s = 'This is a good movie. This is a great movie'
        indexed_string = IndexedString(s)
        self.assertEqual(s, indexed_string.inverse_removing([]))
# Support running this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()