# comparative-explainability / Transformer-Explainability / BERT_rationale_benchmark / models / pipeline / bert_pipeline.py
| # TODO consider if this can be collapsed back down into the pipeline_train.py | |
| import argparse | |
| import json | |
| import logging | |
| import os | |
| import random | |
| from collections import OrderedDict | |
| from itertools import chain | |
| from typing import List, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from BERT_explainability.modules.BERT.BERT_cls_lrp import \ | |
| BertForSequenceClassification as BertForClsOrigLrp | |
| from BERT_explainability.modules.BERT.BertForSequenceClassification import \ | |
| BertForSequenceClassification as BertForSequenceClassificationTest | |
| from BERT_explainability.modules.BERT.ExplanationGenerator import Generator | |
| from BERT_rationale_benchmark.utils import (Annotation, Evidence, | |
| load_datasets, load_documents, | |
| write_jsonl) | |
| from sklearn.metrics import accuracy_score | |
| from transformers import BertForSequenceClassification, BertTokenizer | |
logging.basicConfig(
    level=logging.DEBUG, format="%(relativeCreated)6d %(threadName)s %(message)s"
)
logger = logging.getLogger(__name__)

# Make runs more or less deterministic (not resistant to restarts).
random.seed(12345)
np.random.seed(67890)
torch.manual_seed(10111213)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# (removed a duplicate `import numpy as np` — numpy is already imported
# at the top of the file)

# NOTE(review): latex_special_token appears unused in this module — kept
# for backward compatibility with external importers.
latex_special_token = ["!@#$%^&*()"]
def generate(text_list, attention_list, latex_file, color="red", rescale_value=False):
    """Write a standalone LaTeX document that renders ``text_list`` with each
    token highlighted in ``color`` at an intensity proportional to its score.

    Args:
        text_list: sequence of (word-piece) tokens to render.
        attention_list: 1-D torch tensor of per-token scores; only the first
            ``len(text_list)`` entries are used.
        latex_file: output path of the generated ``.tex`` file.
        color: LaTeX color name for the highlight boxes.
        rescale_value: if True, passes the scores through ``rescale``.
            NOTE(review): ``rescale`` is not defined anywhere in this module,
            so passing True raises NameError — confirm before enabling.
    """
    attention_list = attention_list[: len(text_list)]
    # Normalize scores to [0, 100] for LaTeX color intensities; a constant
    # vector would divide by zero, so map it to all zeros instead.
    if attention_list.max() == attention_list.min():
        attention_list = torch.zeros_like(attention_list)
    else:
        attention_list = (
            100
            * (attention_list - attention_list.min())
            / (attention_list.max() - attention_list.min())
        )
    # Suppress near-zero intensities entirely.
    attention_list[attention_list < 1] = 0
    attention_list = attention_list.tolist()
    # "$" would open math mode in LaTeX, so strip it from every token.
    text_list = [text_list[i].replace("$", "") for i in range(len(text_list))]
    if rescale_value:
        attention_list = rescale(attention_list)
    word_num = len(text_list)
    text_list = clean_word(text_list)
    with open(latex_file, "w") as f:
        f.write(
            r"""\documentclass[varwidth=150mm]{standalone}
\special{papersize=210mm,297mm}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\usepackage{adjustbox}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}"""
            + "\n"
        )
        string = (
            r"""{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{"""
            + "\n"
        )
        for idx in range(word_num):
            # clean_word escaped "#" to "\#", so a word-piece continuation
            # ("##token") now looks like "\#\#token": render it glued to the
            # previous box (no leading space) with the marker removed.
            # (Raw strings used here — the original "\#\#" literal relied on
            # an invalid escape sequence, a SyntaxWarning on modern Python.)
            if r"\#\#" in text_list[idx]:
                token = text_list[idx].replace(r"\#\#", "")
                string += (
                    "\\colorbox{%s!%s}{" % (color, attention_list[idx])
                    + "\\strut "
                    + token
                    + "}"
                )
            else:
                string += (
                    " "
                    + "\\colorbox{%s!%s}{" % (color, attention_list[idx])
                    + "\\strut "
                    + text_list[idx]
                    + "}"
                )
        string += "\n}}}"
        f.write(string + "\n")
        f.write(
            r"""\end{CJK*}
\end{document}"""
        )
def clean_word(word_list):
    """Escape LaTeX-sensitive characters in every word of ``word_list``.

    Each of ``\\ % & ^ # _ { }`` is prefixed with a backslash so the word can
    be embedded verbatim in a LaTeX document. Returns a new list; the input
    list is left untouched.
    """
    # One C-level pass per word via str.translate; equivalent to the chained
    # replace() approach because each escape only ever introduces a backslash,
    # which is never itself re-escaped afterwards.
    escape_table = str.maketrans(
        {ch: "\\" + ch for ch in ("\\", "%", "&", "^", "#", "_", "{", "}")}
    )
    return [word.translate(escape_table) for word in word_list]
def scores_per_word_from_scores_per_token(input, tokenizer, input_ids, scores_per_id):
    """Collapse per-wordpiece scores onto whole input words.

    Expands each non-special wordpiece score over its characters, then walks
    the whitespace-split words in ``input`` and takes the max score over each
    word's character span. Sanity-checks that the characters rebuilt from the
    wordpieces actually spell the input words (prints diagnostics and asserts
    otherwise). Returns a 1-D torch tensor with one score per covered word.
    """
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    tokens = [t.replace("##", "") for t in tokens]
    special = ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]

    # One entry per character of every real (non-special) token...
    char_stream = []
    for tok in tokens:
        if tok not in special:
            char_stream += list(tok)

    # ...and the matching per-character score, repeated over the token length.
    char_scores = []
    for idx in range(len(scores_per_id)):
        if tokens[idx] in special:
            continue
        char_scores += [scores_per_id[idx]] * len(tokens[idx])

    word_scores = []
    rebuilt_words = []
    cursor = 0
    for word in input:
        # Tokenization may have been truncated; stop once scores run out.
        if cursor >= len(char_scores):
            break
        nxt = cursor + len(word)
        word_scores.append(np.max(char_scores[cursor:nxt]))
        rebuilt_words.append("".join(char_stream[cursor:nxt]))
        cursor = nxt

    # All fully-covered words must round-trip exactly (the last one may be cut).
    if rebuilt_words[:-1] != input[: len(rebuilt_words) - 1]:
        print(rebuilt_words)
        print(input[: len(rebuilt_words)])
        print(tokens)
        print(tokenizer.convert_ids_to_tokens(input_ids))
        assert False
    return torch.tensor(word_scores)
def get_input_words(input, tokenizer, input_ids):
    """Reconstruct the original words actually covered by ``input_ids``.

    Joins the characters of all non-special wordpieces and re-segments them
    by the lengths of the whitespace-split words in ``input``, stopping when
    the (possibly truncated) tokenization runs out. Sanity-checks the
    round-trip (prints diagnostics and asserts on mismatch) and returns the
    list of reconstructed words.
    """
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    tokens = [t.replace("##", "") for t in tokens]
    special = ["[CLS]", "[SEP]", "[UNK]", "[PAD]"]

    # Flatten all real tokens into a single character stream.
    chars = []
    for tok in tokens:
        if tok not in special:
            chars += list(tok)

    rebuilt = []
    cursor = 0
    for word in input:
        if cursor >= len(chars):
            break
        nxt = cursor + len(word)
        rebuilt.append("".join(chars[cursor:nxt]))
        cursor = nxt

    # All fully-covered words must round-trip exactly (the last one may be cut).
    if rebuilt[:-1] != input[: len(rebuilt) - 1]:
        print(rebuilt)
        print(input[: len(rebuilt)])
        print(tokens)
        print(tokenizer.convert_ids_to_tokens(input_ids))
        assert False
    return rebuilt
def bert_tokenize_doc(
    doc: List[List[str]], tokenizer, special_token_map
) -> Tuple[List[List[str]], List[List[Tuple[int, int]]]]:
    """Tokenizes a document and returns [start, end) spans to map the wordpieces back to their source words"""
    sents = []
    sent_token_spans = []
    for sentence in doc:
        pieces = []
        spans = []
        cursor = 0
        for word in sentence:
            # Special tokens pass through untouched; everything else is split
            # into wordpieces by the tokenizer.
            if word in special_token_map:
                pieces.append(word)
            else:
                pieces += tokenizer.tokenize(word)
            spans.append((cursor, len(pieces)))
            cursor = len(pieces)
        sents.append(pieces)
        sent_token_spans.append(spans)
    return sents, sent_token_spans
def initialize_models(params: dict, batch_first: bool, use_half_precision=False):
    """Instantiate the BERT evidence classifier and its tokenizer.

    Returns ``(classifier, word_interner, de_interner, evidence_classes,
    tokenizer)`` where ``evidence_classes`` maps class name -> label index.
    Only batch-first layouts are supported; ``use_half_precision`` is
    currently ignored.
    """
    assert batch_first
    # These lookups are unused below but kept so a malformed params dict /
    # tokenizer still fails fast, exactly as before.
    _ = params["max_length"]
    tokenizer = BertTokenizer.from_pretrained(params["bert_vocab"])
    _ = (tokenizer.pad_token_id, tokenizer.cls_token_id, tokenizer.sep_token_id)
    evidence_classes = {
        name: index
        for index, name in enumerate(params["evidence_classifier"]["classes"])
    }
    evidence_classifier = BertForSequenceClassification.from_pretrained(
        params["bert_dir"], num_labels=len(evidence_classes)
    )
    return (
        evidence_classifier,
        tokenizer.vocab,
        tokenizer.ids_to_tokens,
        evidence_classes,
        tokenizer,
    )
# All models in this pipeline expect batch-first tensors; asserted in main().
BATCH_FIRST = True
def extract_docid_from_dataset_element(element):
    """Return the docid of the first Evidence in the element's first evidence group."""
    first_group = next(iter(element.evidences))
    return first_group[0].docid
def extract_evidence_from_dataset_element(element):
    """Return the first evidence group attached to ``element``.

    Raises StopIteration when the element carries no evidence, matching the
    behavior of ``next(iter(...))``.
    """
    groups = iter(element.evidences)
    return next(groups)
def main():
    """Train a BERT evidence classifier on an ERASER-style dataset, then run
    the selected explainability method over the test split.

    Reads --data_dir / --output_dir / --model_params from the command line.
    Side effects under --output_dir: a tokenization cache, model and epoch
    checkpoints, per-threshold rationale JSON files, and (for some methods)
    LaTeX visualizations.
    """
    parser = argparse.ArgumentParser(
        description="""Trains a pipeline model.
Step 1 is evidence identification, that is identify if a given sentence is evidence or not
Step 2 is evidence classification, that is given an evidence sentence, classify the final outcome for the final task
(e.g. sentiment or significance).
These models should be separated into two separate steps, but at the moment:
* prep data (load, intern documents, load json)
* convert data for evidence identification - in the case of training data we take all the positives and sample some
negatives
* side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a
broader sampling of negative values.
* train evidence identification
* convert data for evidence classification - take all rationales + decisions and use this as input
* train evidence classification
* decode first the evidence, then run classification for each split
""",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--data_dir",
        dest="data_dir",
        required=True,
        help="Which directory contains a {train,val,test}.jsonl file?",
    )
    parser.add_argument(
        "--output_dir",
        dest="output_dir",
        required=True,
        help="Where shall we write intermediate models + final data to?",
    )
    parser.add_argument(
        "--model_params",
        dest="model_params",
        required=True,
        help="JSoN file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.",
    )
    args = parser.parse_args()
    assert BATCH_FIRST
    os.makedirs(args.output_dir, exist_ok=True)

    with open(args.model_params, "r") as fp:
        logger.info(f"Loading model parameters from {args.model_params}")
        model_params = json.load(fp)
        logger.info(f"Params: {json.dumps(model_params, indent=2, sort_keys=True)}")
    train, val, test = load_datasets(args.data_dir)
    # Collect every docid referenced by any annotation in any split.
    docids = set(
        e.docid
        for e in chain.from_iterable(
            chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test)))
        )
    )
    documents = load_documents(args.data_dir, docids)
    logger.info(f"Loaded {len(documents)} documents")
    (
        evidence_classifier,
        word_interner,
        de_interner,
        evidence_classes,
        tokenizer,
    ) = initialize_models(model_params, batch_first=BATCH_FIRST)
    logger.info(f"We have {len(word_interner)} wordpieces")

    # Cache tokenized documents on disk so re-runs skip the encode step.
    cache = os.path.join(args.output_dir, "preprocessed.pkl")
    if os.path.exists(cache):
        logger.info(f"Loading interned documents from {cache}")
        (interned_documents) = torch.load(cache)
    else:
        logger.info(f"Interning documents")
        interned_documents = {}
        for d, doc in documents.items():
            # NOTE(review): assumes each document value is acceptable input to
            # encode_plus (a string or token list) — confirm load_documents.
            encoding = tokenizer.encode_plus(
                doc,
                add_special_tokens=True,
                max_length=model_params["max_length"],
                return_token_type_ids=False,
                pad_to_max_length=False,
                return_attention_mask=True,
                return_tensors="pt",
                truncation=True,
            )
            interned_documents[d] = encoding
        torch.save((interned_documents), cache)

    # Training requires a CUDA device; there is no CPU fallback here.
    evidence_classifier = evidence_classifier.cuda()
    optimizer = None
    scheduler = None
    save_dir = args.output_dir

    logging.info(f"Beginning training classifier")
    evidence_classifier_output_dir = os.path.join(save_dir, "classifier")
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(evidence_classifier_output_dir, exist_ok=True)
    model_save_file = os.path.join(evidence_classifier_output_dir, "classifier.pt")
    epoch_save_file = os.path.join(
        evidence_classifier_output_dir, "classifier_epoch_data.pt"
    )

    device = next(evidence_classifier.parameters()).device
    if optimizer is None:
        optimizer = torch.optim.Adam(
            evidence_classifier.parameters(),
            lr=model_params["evidence_classifier"]["lr"],
        )
    # reduction="none": per-sample losses are summed per batch below, then
    # averaged over the epoch by example count.
    criterion = nn.CrossEntropyLoss(reduction="none")
    batch_size = model_params["evidence_classifier"]["batch_size"]
    epochs = model_params["evidence_classifier"]["epochs"]
    patience = model_params["evidence_classifier"]["patience"]
    max_grad_norm = model_params["evidence_classifier"].get("max_grad_norm", None)

    # NOTE(review): class_labels appears unused below.
    class_labels = [k for k, v in sorted(evidence_classes.items())]

    results = {
        "train_loss": [],
        "train_f1": [],
        "train_acc": [],
        "val_loss": [],
        "val_f1": [],
        "val_acc": [],
    }
    best_epoch = -1
    best_val_acc = 0
    best_val_loss = float("inf")
    best_model_state_dict = None
    start_epoch = 0
    epoch_data = {}
    # Resume from checkpoint if a previous run left one behind.
    if os.path.exists(epoch_save_file):
        logging.info(f"Restoring model from {model_save_file}")
        evidence_classifier.load_state_dict(torch.load(model_save_file))
        epoch_data = torch.load(epoch_save_file)
        start_epoch = epoch_data["epoch"] + 1
        # handle finishing because patience was exceeded or we didn't get the best final epoch
        if bool(epoch_data.get("done", 0)):
            start_epoch = epochs
            results = epoch_data["results"]
            best_epoch = start_epoch
            best_model_state_dict = OrderedDict(
                {k: v.cpu() for k, v in evidence_classifier.state_dict().items()}
            )
        logging.info(f"Restoring training from epoch {start_epoch}")
    logging.info(
        f"Training evidence classifier from epoch {start_epoch} until epoch {epochs}"
    )
    optimizer.zero_grad()
    for epoch in range(start_epoch, epochs):
        # random.sample with k=len(train) shuffles without mutating `train`.
        epoch_train_data = random.sample(train, k=len(train))
        epoch_train_loss = 0
        epoch_training_acc = 0
        evidence_classifier.train()
        logging.info(
            f"Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples"
        )
        for batch_start in range(0, len(epoch_train_data), batch_size):
            batch_elements = epoch_train_data[
                batch_start : min(batch_start + batch_size, len(epoch_train_data))
            ]
            targets = [evidence_classes[s.classification] for s in batch_elements]
            targets = torch.tensor(targets, dtype=torch.long, device=device)
            samples_encoding = [
                interned_documents[extract_docid_from_dataset_element(s)]
                for s in batch_elements
            ]
            # Each cached encoding holds a (1, seq_len) tensor; stack then
            # squeeze the singleton dim to get (batch, seq_len).
            input_ids = (
                torch.stack(
                    [
                        samples_encoding[i]["input_ids"]
                        for i in range(len(samples_encoding))
                    ]
                )
                .squeeze(1)
                .to(device)
            )
            attention_masks = (
                torch.stack(
                    [
                        samples_encoding[i]["attention_mask"]
                        for i in range(len(samples_encoding))
                    ]
                )
                .squeeze(1)
                .to(device)
            )
            preds = evidence_classifier(
                input_ids=input_ids, attention_mask=attention_masks
            )[0]
            # normalize=False => count of correct predictions, not a ratio.
            epoch_training_acc += accuracy_score(
                preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False
            )
            loss = criterion(preds, targets.to(device=preds.device)).sum()
            epoch_train_loss += loss.item()
            loss.backward()
            assert loss == loss  # for nans
            if max_grad_norm:
                torch.nn.utils.clip_grad_norm_(
                    evidence_classifier.parameters(), max_grad_norm
                )
            optimizer.step()
            if scheduler:
                scheduler.step()
            optimizer.zero_grad()
        # Per-example averages (losses/acc above accumulate raw sums/counts).
        epoch_train_loss /= len(epoch_train_data)
        epoch_training_acc /= len(epoch_train_data)
        assert epoch_train_loss == epoch_train_loss  # for nans
        results["train_loss"].append(epoch_train_loss)
        logging.info(f"Epoch {epoch} training loss {epoch_train_loss}")
        logging.info(f"Epoch {epoch} training accuracy {epoch_training_acc}")

        with torch.no_grad():
            epoch_val_loss = 0
            epoch_val_acc = 0
            epoch_val_data = random.sample(val, k=len(val))
            evidence_classifier.eval()
            val_batch_size = 32
            logging.info(
                f"Validating with {len(epoch_val_data) // val_batch_size} batches with {len(epoch_val_data)} examples"
            )
            for batch_start in range(0, len(epoch_val_data), val_batch_size):
                batch_elements = epoch_val_data[
                    batch_start : min(batch_start + val_batch_size, len(epoch_val_data))
                ]
                targets = [evidence_classes[s.classification] for s in batch_elements]
                targets = torch.tensor(targets, dtype=torch.long, device=device)
                samples_encoding = [
                    interned_documents[extract_docid_from_dataset_element(s)]
                    for s in batch_elements
                ]
                input_ids = (
                    torch.stack(
                        [
                            samples_encoding[i]["input_ids"]
                            for i in range(len(samples_encoding))
                        ]
                    )
                    .squeeze(1)
                    .to(device)
                )
                attention_masks = (
                    torch.stack(
                        [
                            samples_encoding[i]["attention_mask"]
                            for i in range(len(samples_encoding))
                        ]
                    )
                    .squeeze(1)
                    .to(device)
                )
                preds = evidence_classifier(
                    input_ids=input_ids, attention_mask=attention_masks
                )[0]
                epoch_val_acc += accuracy_score(
                    preds.argmax(dim=1).cpu(), targets.cpu(), normalize=False
                )
                loss = criterion(preds, targets.to(device=preds.device)).sum()
                epoch_val_loss += loss.item()
            epoch_val_loss /= len(val)
            epoch_val_acc /= len(val)
            results["val_acc"].append(epoch_val_acc)
            # NOTE(review): this overwrites the "val_loss" list with a scalar
            # instead of appending (unlike train_loss/val_acc) — likely a bug;
            # confirm before relying on the val_loss history in `results`.
            results["val_loss"] = epoch_val_loss
            logging.info(f"Epoch {epoch} val loss {epoch_val_loss}")
            logging.info(f"Epoch {epoch} val acc {epoch_val_acc}")

            # Checkpoint on best val accuracy (ties broken by lower val loss).
            if epoch_val_acc > best_val_acc or (
                epoch_val_acc == best_val_acc and epoch_val_loss < best_val_loss
            ):
                best_model_state_dict = OrderedDict(
                    {k: v.cpu() for k, v in evidence_classifier.state_dict().items()}
                )
                best_epoch = epoch
                best_val_acc = epoch_val_acc
                best_val_loss = epoch_val_loss
                epoch_data = {
                    "epoch": epoch,
                    "results": results,
                    "best_val_acc": best_val_acc,
                    "done": 0,
                }
                torch.save(evidence_classifier.state_dict(), model_save_file)
                torch.save(epoch_data, epoch_save_file)
                logging.debug(
                    f"Epoch {epoch} new best model with val accuracy {epoch_val_acc}"
                )
        # Early stopping once no improvement for `patience` epochs.
        if epoch - best_epoch > patience:
            logging.info(f"Exiting after epoch {epoch} due to no improvement")
            epoch_data["done"] = 1
            torch.save(epoch_data, epoch_save_file)
            break
    epoch_data["done"] = 1
    epoch_data["results"] = results
    torch.save(epoch_data, epoch_save_file)
    # NOTE(review): best_model_state_dict is None if no epoch ever ran and no
    # "done" checkpoint existed — load_state_dict would fail; confirm intended.
    evidence_classifier.load_state_dict(best_model_state_dict)
    evidence_classifier = evidence_classifier.to(device=device)
    evidence_classifier.eval()

    # test
    # Two explainability-instrumented copies of the trained classifier: one
    # for transformer attribution, one for the original-LRP-style methods.
    test_classifier = BertForSequenceClassificationTest.from_pretrained(
        model_params["bert_dir"], num_labels=len(evidence_classes)
    ).to(device)
    orig_lrp_classifier = BertForClsOrigLrp.from_pretrained(
        model_params["bert_dir"], num_labels=len(evidence_classes)
    ).to(device)
    if os.path.exists(epoch_save_file):
        logging.info(f"Restoring model from {model_save_file}")
        test_classifier.load_state_dict(torch.load(model_save_file))
        orig_lrp_classifier.load_state_dict(torch.load(model_save_file))
        test_classifier.eval()
        orig_lrp_classifier.eval()
    test_batch_size = 1
    logging.info(
        f"Testing with {len(test) // test_batch_size} batches with {len(test)} examples"
    )

    # explainability
    explanations = Generator(test_classifier)
    explanations_orig_lrp = Generator(orig_lrp_classifier)
    # Select which attribution method to run; "ground_truth" and
    # "generate_all" short-circuit the per-example loop below.
    method = "transformer_attribution"
    method_folder = {
        "transformer_attribution": "ours",
        "partial_lrp": "partial_lrp",
        "last_attn": "last_attn",
        "attn_gradcam": "attn_gradcam",
        "lrp": "lrp",
        "rollout": "rollout",
        "ground_truth": "ground_truth",
        "generate_all": "generate_all",
    }
    method_expl = {
        "transformer_attribution": explanations.generate_LRP,
        "partial_lrp": explanations_orig_lrp.generate_LRP_last_layer,
        "last_attn": explanations_orig_lrp.generate_attn_last_layer,
        "attn_gradcam": explanations_orig_lrp.generate_attn_gradcam,
        "lrp": explanations_orig_lrp.generate_full_lrp,
        "rollout": explanations_orig_lrp.generate_rollout,
    }

    os.makedirs(os.path.join(args.output_dir, method_folder[method]), exist_ok=True)
    # One output file per top-k threshold (k = 5, 10, ..., 80).
    result_files = []
    for i in range(5, 85, 5):
        result_files.append(
            open(
                os.path.join(
                    args.output_dir, "{0}/identifier_results_{1}.json"
                ).format(method_folder[method], i),
                "w",
            )
        )

    j = 0
    for batch_start in range(0, len(test), test_batch_size):
        batch_elements = test[
            batch_start : min(batch_start + test_batch_size, len(test))
        ]
        targets = [evidence_classes[s.classification] for s in batch_elements]
        targets = torch.tensor(targets, dtype=torch.long, device=device)
        samples_encoding = [
            interned_documents[extract_docid_from_dataset_element(s)]
            for s in batch_elements
        ]
        input_ids = (
            torch.stack(
                [
                    samples_encoding[i]["input_ids"]
                    for i in range(len(samples_encoding))
                ]
            )
            .squeeze(1)
            .to(device)
        )
        attention_masks = (
            torch.stack(
                [
                    samples_encoding[i]["attention_mask"]
                    for i in range(len(samples_encoding))
                ]
            )
            .squeeze(1)
            .to(device)
        )
        preds = test_classifier(
            input_ids=input_ids, attention_mask=attention_masks
        )[0]
        for s in batch_elements:
            doc_name = extract_docid_from_dataset_element(s)
            # NOTE(review): assumes documents[doc_name] is a raw string —
            # confirm against load_documents' return type.
            inp = documents[doc_name].split()
            # targets.item() requires test_batch_size == 1 (set above); the
            # neg/pos naming assumes a binary task.
            classification = "neg" if targets.item() == 0 else "pos"
            is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0
            if method == "generate_all":
                # Assemble a LaTeX grid that references the per-method PDFs
                # produced by earlier runs of the other methods.
                file_name = "{0}_{1}_{2}.tex".format(
                    j, classification, is_classification_correct
                )
                GT_global = os.path.join(
                    args.output_dir, "{0}/visual_results_{1}.pdf"
                ).format(method_folder["ground_truth"], j)
                GT_ours = os.path.join(
                    args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
                ).format(
                    method_folder["transformer_attribution"],
                    j,
                    classification,
                    is_classification_correct,
                )
                CF_ours = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
                    method_folder["transformer_attribution"], j
                )
                GT_partial = os.path.join(
                    args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
                ).format(
                    method_folder["partial_lrp"],
                    j,
                    classification,
                    is_classification_correct,
                )
                CF_partial = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
                    method_folder["partial_lrp"], j
                )
                GT_gradcam = os.path.join(
                    args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
                ).format(
                    method_folder["attn_gradcam"],
                    j,
                    classification,
                    is_classification_correct,
                )
                CF_gradcam = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
                    method_folder["attn_gradcam"], j
                )
                GT_lrp = os.path.join(
                    args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
                ).format(
                    method_folder["lrp"],
                    j,
                    classification,
                    is_classification_correct,
                )
                CF_lrp = os.path.join(args.output_dir, "{0}/{1}_CF.pdf").format(
                    method_folder["lrp"], j
                )
                GT_lastattn = os.path.join(
                    args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
                ).format(
                    method_folder["last_attn"],
                    j,
                    classification,
                    is_classification_correct,
                )
                GT_rollout = os.path.join(
                    args.output_dir, "{0}/{1}_GT_{2}_{3}.pdf"
                ).format(
                    method_folder["rollout"],
                    j,
                    classification,
                    is_classification_correct,
                )
                with open(file_name, "w") as f:
                    f.write(
                        r"""\documentclass[varwidth]{standalone}
\usepackage{color}
\usepackage{tcolorbox}
\usepackage{CJK}
\tcbset{width=0.9\textwidth,boxrule=0pt,colback=red,arc=0pt,auto outer arc,left=0pt,right=0pt,boxsep=5pt}
\begin{document}
\begin{CJK*}{UTF8}{gbsn}
{\setlength{\fboxsep}{0pt}\colorbox{white!0}{\parbox{0.9\textwidth}{
\setlength{\tabcolsep}{2pt} % Default value: 6pt
\begin{tabular}{ccc}
\includegraphics[width=0.32\linewidth]{"""
                        + GT_global
                        + """}&
\includegraphics[width=0.32\linewidth]{"""
                        + GT_ours
                        + """}&
\includegraphics[width=0.32\linewidth]{"""
                        + CF_ours
                        + """}\\\\
(a) & (b) & (c)\\\\
\includegraphics[width=0.32\linewidth]{"""
                        + GT_partial
                        + """}&
\includegraphics[width=0.32\linewidth]{"""
                        + CF_partial
                        + """}&
\includegraphics[width=0.32\linewidth]{"""
                        + GT_gradcam
                        + """}\\\\
(d) & (e) & (f)\\\\
\includegraphics[width=0.32\linewidth]{"""
                        + CF_gradcam
                        + """}&
\includegraphics[width=0.32\linewidth]{"""
                        + GT_lrp
                        + """}&
\includegraphics[width=0.32\linewidth]{"""
                        + CF_lrp
                        + """}\\\\
(g) & (h) & (i)\\\\
\includegraphics[width=0.32\linewidth]{"""
                        + GT_lastattn
                        + """}&
\includegraphics[width=0.32\linewidth]{"""
                        + GT_rollout
                        + """}&\\\\
(j) & (k)&\\\\
\end{tabular}
}}}
\end{CJK*}
\end{document}
)"""
                    )
                j += 1
                break
            if method == "ground_truth":
                # Paint a 0/1 "cam" over the words covered by the annotated
                # evidence spans, then render it.
                inp_cropped = get_input_words(inp, tokenizer, input_ids[0])
                cam = torch.zeros(len(inp_cropped))
                for evidence in extract_evidence_from_dataset_element(s):
                    start_idx = evidence.start_token
                    if start_idx >= len(cam):
                        break
                    end_idx = evidence.end_token
                    cam[start_idx:end_idx] = 1
                generate(
                    inp_cropped,
                    cam,
                    (
                        os.path.join(
                            args.output_dir, "{0}/visual_results_{1}.tex"
                        ).format(method_folder[method], j)
                    ),
                    color="green",
                )
                j = j + 1
                break
            text = tokenizer.convert_ids_to_tokens(input_ids[0])
            classification = "neg" if targets.item() == 0 else "pos"
            is_classification_correct = 1 if preds.argmax(dim=1) == targets else 0
            target_idx = targets.item()
            # Attribution toward the gold class ...
            cam_target = method_expl[method](
                input_ids=input_ids,
                attention_mask=attention_masks,
                index=target_idx,
            )[0]
            cam_target = cam_target.clamp(min=0)
            generate(
                text,
                cam_target,
                (
                    os.path.join(args.output_dir, "{0}/{1}_GT_{2}_{3}.tex").format(
                        method_folder[method],
                        j,
                        classification,
                        is_classification_correct,
                    )
                ),
            )
            if method in [
                "transformer_attribution",
                "partial_lrp",
                "attn_gradcam",
                "lrp",
            ]:
                # ... and toward the counterfactual class (1 - target assumes
                # a binary task) for the class-specific methods.
                cam_false_class = method_expl[method](
                    input_ids=input_ids,
                    attention_mask=attention_masks,
                    index=1 - target_idx,
                )[0]
                cam_false_class = cam_false_class.clamp(min=0)
                generate(
                    text,
                    cam_false_class,
                    (
                        os.path.join(args.output_dir, "{0}/{1}_CF.tex").format(
                            method_folder[method], j
                        )
                    ),
                )
            cam = cam_target
            cam = scores_per_word_from_scores_per_token(
                inp, tokenizer, input_ids[0], cam
            )
            j = j + 1
            doc_name = extract_docid_from_dataset_element(s)
            hard_rationales = []
            # NOTE(review): hard_rationales is never reset between thresholds,
            # so the top-5 entries are repeated inside the top-10 output and so
            # on — confirm this cumulative behavior is intended.
            for res, i in enumerate(range(5, 85, 5)):
                print("calculating top ", i)
                _, indices = cam.topk(k=i)
                for index in indices.tolist():
                    hard_rationales.append(
                        {"start_token": index, "end_token": index + 1}
                    )
                result_dict = {
                    "annotation_id": doc_name,
                    "rationales": [
                        {
                            "docid": doc_name,
                            "hard_rationale_predictions": hard_rationales,
                        }
                    ],
                }
                result_files[res].write(json.dumps(result_dict) + "\n")
    for i in range(len(result_files)):
        result_files[i].close()
| if __name__ == "__main__": | |
| main() | |