Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| sys.path.append('BERT') | |
| from transformers import BertTokenizer | |
| from BERT_explainability.modules.BERT.ExplanationGenerator import Generator | |
| from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification | |
| from transformers import AutoTokenizer | |
| from captum.attr import visualization | |
| import torch | |
| from sequenceoutput.modeling_output import SequenceClassifierOutput | |
# Label names indexed by class id, and the class id reported as the
# "true" label in the rendered visualization.
classifications = ["NEGATIVE", "POSITIVE"]
true_class = 1

# The model weights and the tokenizer files are loaded from the same
# local checkpoint directory.
_checkpoint_dir = "./BERT/BERT_weight"
tokenizer = AutoTokenizer.from_pretrained(_checkpoint_dir)
model = BertForSequenceClassification.from_pretrained(_checkpoint_dir)
model.eval()  # switch to inference mode

# Explanation generator wrapping the loaded model.
explanations = Generator(model)
def generate_visual(text_batch, target_class):
    """Build per-token relevance scores and an HTML visualization for a text.

    The text is tokenized, relevance scores for ``target_class`` are requested
    from the module-level ``explanations`` generator (``generate_LRP``),
    min-max scaled to the range [0, 1] and negated when the target class is
    "NEGATIVE". The model is then run once more to obtain the predicted class
    and its probability for display.

    Args:
        text_batch: Text handed straight to the tokenizer. Only the first row
            of the generator output and of the model output is used, so this
            is presumably a single string -- TODO confirm against the caller.
        target_class: Index into ``classifications`` (0 is "NEGATIVE", 1 is
            "POSITIVE") for which the explanation is generated.

    Returns:
        A ``(token_importance, html)`` tuple. ``token_importance`` maps each
        token string to its score rounded to three decimals; a token that
        occurs several times keeps only the score of its last occurrence.
        ``html`` is the ``.data`` attribute of the object returned by
        captum's ``visualize_text``.
    """
    encoding = tokenizer(text_batch, return_tensors='pt')
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']

    # start_layer=11 is presumably the last encoder layer of a 12-layer
    # BERT -- TODO confirm against the checkpoint configuration.
    expl = explanations.generate_LRP(
        input_ids=input_ids,
        attention_mask=attention_mask,
        start_layer=11,
        index=target_class,
    )[0]

    # Min-max scale the scores. When every score is identical the range is
    # zero and the plain formula would produce 0/0 = NaN for all tokens, so
    # fall back to an all-zero relevance vector in that case.
    lowest = expl.min()
    score_range = expl.max() - lowest
    if score_range == 0:
        expl = torch.zeros_like(expl)
    else:
        expl = (expl - lowest) / score_range

    # NOTE(review): this forward pass is intentionally not wrapped in
    # torch.no_grad(); the explainability model may register gradient hooks
    # during forward -- confirm before adding one.
    logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
    output = torch.nn.functional.softmax(logits, dim=-1)
    classification = output.argmax(dim=-1).item()

    # Flip the sign so that relevance towards the negative class is rendered
    # as negative attribution.
    if classifications[target_class] == "NEGATIVE":
        expl = -expl

    tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
    token_importance = {
        token: round(score, 3)
        for token, score in zip(tokens, expl.tolist())
    }

    # Positional arguments follow captum's VisualizationDataRecord order:
    # word_attributions, pred_prob, pred_class, true_class, attr_class,
    # attr_score, raw_input_ids, convergence_score.
    # NOTE(review): the attribution label is set to true_class although the
    # scores were computed for target_class -- confirm this is intended.
    vis_data_records = [visualization.VisualizationDataRecord(
        expl,
        output[0][classification],
        classification,
        true_class,
        true_class,
        1,
        tokens,
        1,
    )]
    html_page = visualization.visualize_text(vis_data_records)
    return token_importance, html_page.data