import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization

from util import visualize_text

classifications = ["NEGATIVE", "POSITIVE"]

class IntegratedGradientsExplainer:
    def __init__(self):
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "textattack/roberta-base-SST-2"
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
        # Use the unknown token as the reference (baseline) input for integrated gradients.
        self.ref_token_id = self.tokenizer.unk_token_id

    def tokens_from_ids(self, ids):
        # Strip the "Ġ" marker that the RoBERTa BPE tokenizer prepends to word-initial tokens.
        return list(map(lambda s: s[1:] if s[0] == "Ġ" else s, self.tokenizer.convert_ids_to_tokens(ids)))

    def custom_forward(self, inputs, attention_mask=None):
        # Forward wrapper for Captum: return the raw classification logits.
        result = self.model(inputs, attention_mask=attention_mask, return_dict=True)
        return result.logits

    @staticmethod
    def summarize_attributions(attributions):
        # Collapse per-embedding-dimension attributions into one score per token and normalize.
        attributions = attributions.sum(dim=-1).squeeze(0)
        attributions = attributions / torch.norm(attributions)
        return attributions
    def run_attribution_model(self, input_ids, attention_mask, index=None, layer=None, steps=20):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        if index is None:
            # Default to attributing the model's own predicted class.
            index = output.argmax(axis=-1).item()
        ablator = LayerIntegratedGradients(self.custom_forward, layer)
        attributions = ablator.attribute(
            inputs=input_ids,
            baselines=self.ref_token_id,
            additional_forward_args=(attention_mask,),
            target=index,
            n_steps=steps,
        )
        return self.summarize_attributions(attributions).unsqueeze_(0), output, index
    def build_visualization(self, input_ids, attention_mask, **kwargs):
        vis_data_records = []
        attributions, output, index = self.run_attribution_model(input_ids, attention_mask, **kwargs)
        for record in range(input_ids.size(0)):
            classification = output[record].argmax(dim=-1).item()
            class_name = classifications[classification]
            attr = attributions[record]
            # Drop the leading <s> token as well as the trailing </s> and any padding tokens.
            tokens = self.tokens_from_ids(input_ids[record].flatten())[
                1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
            ]
            vis_data_records.append(
                visualization.VisualizationDataRecord(
                    attr,
                    output[record][classification],
                    class_name,
                    class_name,
                    index,
                    1,
                    tokens,
                    1,
                )
            )
        return visualize_text(vis_data_records)
    def __call__(self, input_text, layer):
        text_batch = [input_text]
        encoding = self.tokenizer(text_batch, return_tensors="pt")
        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)
        layer = int(layer)
        if layer == 0:
            # Layer 0 attributes against the word embeddings; any other value selects the
            # (layer - 1)-th transformer encoder block.
            layer = self.model.roberta.embeddings
        else:
            layer = getattr(self.model.roberta.encoder.layer, str(layer - 1))
        return self.build_visualization(input_ids, attention_mask, layer=layer)
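
# Minimal usage sketch (not part of the original file): instantiate the explainer and
# attribute a hypothetical sentence against the embedding layer (layer 0). The sample
# sentence is an assumption, and visualize_text is assumed to return renderable HTML.
if __name__ == "__main__":
    explainer = IntegratedGradientsExplainer()
    html = explainer("A deeply moving and beautifully acted film.", layer=0)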