# Snippet scraped from a Hugging Face Space (the Space page reported "Runtime error").
import os

import spacy
import torch
from captum.attr import visualization
from IPython.display import HTML, Image, display
from transformers import AutoTokenizer, BertTokenizer

from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification
from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from sequenceoutput.modeling_output import SequenceClassifierOutput
# Demo: generate a layer-wise relevance propagation (LRP) explanation for a
# BERT sentiment prediction and render it with Captum's text visualizer.

# Use the GPU when present; the original hard-coded "cuda" and crashed on
# CPU-only hosts (likely the Space's "Runtime error").
device = "cuda" if torch.cuda.is_available() else "cpu"

model = BertForSequenceClassification.from_pretrained("./BERT").to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("./BERT")

# Initialize the explanations generator.
explanations = Generator(model)

classifications = ["NEGATIVE", "POSITIVE"]

# Encode a sentence.
text_batch = ["I hate that I love you."]
encoding = tokenizer(text_batch, return_tensors='pt')
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)

# True class is positive - 1.
true_class = 1

# Generate an explanation for the input with respect to the NEGATIVE class (0).
target_class = 0
expl = explanations.generate_LRP(
    input_ids=input_ids,
    attention_mask=attention_mask,
    start_layer=11,
    index=target_class,
)[0]

# Min-max normalize relevance scores into [0, 1].
expl = (expl - expl.min()) / (expl.max() - expl.min())

# Get the model classification (softmax over logits, argmax over classes).
output = torch.nn.functional.softmax(
    model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1
)
classification = output.argmax(dim=-1).item()

# If the explained class is negative, higher explanation scores are "more
# negative" — flip the sign for visualization.
# NOTE(review): this keys off the *explained* class (target_class) while the
# original comment spoke of the model's *classification* — confirm which is
# intended before changing it.
class_name = classifications[target_class]
if class_name == "NEGATIVE":
    expl *= (-1)

# Map each token to its relevance score.
# NOTE(review): repeated tokens collide in this dict (later occurrences
# overwrite earlier ones); keep (token, score) pairs if positions matter.
token_importance = {}
tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
for token, score in zip(tokens, expl):
    token_importance[token] = score.item()

# Build a Captum visualization record:
# (word attributions, predicted prob, predicted class, true class,
#  attribution class, attribution score, raw tokens, convergence delta).
vis_data_records = [visualization.VisualizationDataRecord(
    expl,
    output[0][classification],
    classification,
    true_class,
    true_class,
    1,
    tokens,
    1)]
html1 = visualization.visualize_text(vis_data_records)
# print(token_importance, html1)
# with open('bert-xai.html', 'w+') as f:
#     f.write(str(html1))
# return token_importance, html1