| """Wrapper of AllenNLP model. Fixes errors based on model predictions""" |
| import logging |
| import os |
| import sys |
| from time import time |
|
|
| import torch |
| from allennlp.data.dataset import Batch |
| from allennlp.data.fields import TextField |
| from allennlp.data.instance import Instance |
| from allennlp.data.tokenizers import Token |
| from allennlp.data.vocabulary import Vocabulary |
| from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder |
| from allennlp.nn import util |
|
|
| from gector.bert_token_embedder import PretrainedBertEmbedder |
| from gector.seq2labels_model import Seq2Labels |
| from gector.tokenizer_indexer import PretrainedBertIndexer |
| from utils.helpers import PAD, UNK, get_target_sent_by_edits, START_TOKEN |
| from utils.helpers import get_weights_name |
|
|
| logging.getLogger("werkzeug").setLevel(logging.ERROR) |
| logger = logging.getLogger(__file__) |
|
|
|
|
| class GecBERTModel(object): |
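    """Wraps one or more Seq2Labels checkpoints and corrects batches of
    tokenized sentences by iteratively applying the predicted edit tags
    (see ``handle_batch``).
    """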
    def __init__(self, vocab_path=None, model_paths=None,
                 weigths=None,
                 max_len=50,
                 min_len=3,
                 lowercase_tokens=False,
                 log=False,
                 iterations=3,
                 model_name='roberta',
                 special_tokens_fix=1,
                 is_ensemble=True,
                 min_error_probability=0.0,
                 confidence=0,
                 del_confidence=0,
                 resolve_cycles=False,
                 ):
        self.model_weights = list(map(float, weigths)) if weigths else [1] * len(model_paths)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.max_len = max_len
        self.min_len = min_len
        self.lowercase_tokens = lowercase_tokens
        self.min_error_probability = min_error_probability
        self.vocab = Vocabulary.from_files(vocab_path)
        self.log = log
        self.iterations = iterations
        self.confidence = confidence
        self.del_conf = del_confidence
        self.resolve_cycles = resolve_cycles

        # Build one token indexer and one Seq2Labels model per checkpoint path.
        self.indexers = []
        self.models = []
        for model_path in model_paths:
            if is_ensemble:
                model_name, special_tokens_fix = self._get_model_data(model_path)
            weights_name = get_weights_name(model_name, lowercase_tokens)
            self.indexers.append(self._get_indexer(weights_name, special_tokens_fix))
            model = Seq2Labels(vocab=self.vocab,
                               text_field_embedder=self._get_embbeder(weights_name, special_tokens_fix),
                               confidence=self.confidence,
                               del_confidence=self.del_conf,
                               ).to(self.device)
            if torch.cuda.is_available():
                model.load_state_dict(torch.load(model_path), strict=False)
            else:
                model.load_state_dict(torch.load(model_path,
                                                 map_location=torch.device('cpu')),
                                      strict=False)
            model.eval()
            self.models.append(model)

    @staticmethod
    def _get_model_data(model_path):
        """Infer (transformer name, special_tokens_fix) from a checkpoint
        filename, e.g. 'roberta_1_gector.th' -> ('roberta', 1)."""
        model_name = model_path.split('/')[-1]
        tr_model, stf = model_name.split('_')[:2]
        return tr_model, int(stf)

    def _restore_model(self, input_path):
        if os.path.isdir(input_path):
            print("Model could not be restored from directory", file=sys.stderr)
            filenames = []
        else:
            filenames = [input_path]
        for model_path in filenames:
            try:
                if torch.cuda.is_available():
                    loaded_model = torch.load(model_path)
                else:
                    loaded_model = torch.load(model_path,
                                              map_location=lambda storage, loc: storage)
            except Exception:
                print(f"{model_path} is not a valid model", file=sys.stderr)
                continue
            # Merge the loaded weights into `self.model` (which the caller is
            # expected to have set); accumulate when averaging several checkpoints.
            own_state = self.model.state_dict()
            for name, weights in loaded_model.items():
                if name not in own_state:
                    continue
                try:
                    if len(filenames) == 1:
                        own_state[name].copy_(weights)
                    else:
                        own_state[name] += weights
                except RuntimeError:
                    continue
        print("Model is restored", file=sys.stderr)

    def predict(self, batches):
        """Run each model on its corresponding batch and merge the outputs."""
        t11 = time()
        predictions = []
        for batch, model in zip(batches, self.models):
            batch = util.move_to_device(batch.as_tensor_dict(), 0 if torch.cuda.is_available() else -1)
            with torch.no_grad():
                prediction = model.forward(**batch)
            predictions.append(prediction)

        preds, idx, error_probs = self._convert(predictions)
        t55 = time()
        if self.log:
            print(f"Inference time {t55 - t11}")
        return preds, idx, error_probs

    def get_token_action(self, token, index, prob, sugg_token):
        """Get the suggested edit action for a token.

        Returns a (start, end, replacement, probability) tuple, e.g. a
        '$REPLACE_the' tag at position 3 yields (2, 3, 'the', prob), or
        None when no action should be applied.
        """
        # Skip low-probability edits and keep/pad/unknown tags.
        if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
            return None

        if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE':
            start_pos = index
            end_pos = index + 1
        elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
            start_pos = index + 1
            end_pos = index + 1
        else:
            # Unexpected tag format; treat it as a no-op.
            return None

        if sugg_token == "$DELETE":
            sugg_token_clear = ""
        elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"):
            sugg_token_clear = sugg_token[:]
        else:
            sugg_token_clear = sugg_token[sugg_token.index('_') + 1:]

        return start_pos - 1, end_pos - 1, sugg_token_clear, prob

    def _get_embbeder(self, weigths_name, special_tokens_fix):
        embedders = {'bert': PretrainedBertEmbedder(
            pretrained_model=weigths_name,
            requires_grad=False,
            top_layer_only=True,
            special_tokens_fix=special_tokens_fix)
        }
        text_field_embedder = BasicTextFieldEmbedder(
            token_embedders=embedders,
            embedder_to_indexer_map={"bert": ["bert", "bert-offsets"]},
            allow_unmatched_keys=True)
        return text_field_embedder

    def _get_indexer(self, weights_name, special_tokens_fix):
        bert_token_indexer = PretrainedBertIndexer(
            pretrained_model=weights_name,
            do_lowercase=self.lowercase_tokens,
            max_pieces_per_token=5,
            special_tokens_fix=special_tokens_fix
        )
        return {'bert': bert_token_indexer}

    def preprocess(self, token_batch):
        """Truncate sequences, prepend the $START token and build one indexed
        Batch per model."""
        seq_lens = [len(sequence) for sequence in token_batch if sequence]
        if not seq_lens:
            return []
        max_len = min(max(seq_lens), self.max_len)
        batches = []
        for indexer in self.indexers:
            batch = []
            for sequence in token_batch:
                tokens = sequence[:max_len]
                tokens = [Token(token) for token in ['$START'] + tokens]
                batch.append(Instance({'tokens': TextField(tokens, indexer)}))
            batch = Batch(batch)
            batch.index_instances(self.vocab)
            batches.append(batch)

        return batches

    def _convert(self, data):
        """Average the per-model outputs, weighted by self.model_weights, and
        return (label probabilities, label indices, sentence error probabilities)."""
        all_class_probs = torch.zeros_like(data[0]['class_probabilities_labels'])
        error_probs = torch.zeros_like(data[0]['max_error_probability'])
        for output, weight in zip(data, self.model_weights):
            all_class_probs += weight * output['class_probabilities_labels'] / sum(self.model_weights)
            error_probs += weight * output['max_error_probability'] / sum(self.model_weights)

        max_vals = torch.max(all_class_probs, dim=-1)
        probs = max_vals[0].tolist()
        idx = max_vals[1].tolist()
        return probs, idx, error_probs.tolist()

    def update_final_batch(self, final_batch, pred_ids, pred_batch,
                           prev_preds_dict):
        new_pred_ids = []
        total_updated = 0
        for i, orig_id in enumerate(pred_ids):
            orig = final_batch[orig_id]
            pred = pred_batch[i]
            prev_preds = prev_preds_dict[orig_id]
            if orig != pred and pred not in prev_preds:
                final_batch[orig_id] = pred
                new_pred_ids.append(orig_id)
                prev_preds_dict[orig_id].append(pred)
                total_updated += 1
            elif orig != pred and pred in prev_preds:
                # The prediction cycles back to an earlier output:
                # accept it, but stop iterating on this sentence.
                final_batch[orig_id] = pred
                total_updated += 1
            else:
                continue
        return final_batch, new_pred_ids, total_updated

    def postprocess_batch(self, batch, all_probabilities, all_idxs,
                          error_probs):
        all_results = []
        noop_index = self.vocab.get_token_index("$KEEP", "labels")
        for tokens, probabilities, idxs, error_prob in zip(batch,
                                                           all_probabilities,
                                                           all_idxs,
                                                           error_probs):
            length = min(len(tokens), self.max_len)
            edits = []

            # Skip the whole sentence if no errors were predicted.
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # Skip the whole sentence if the sentence-level error probability is too low.
            if error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            for i in range(length + 1):
                # Position 0 corresponds to the $START token.
                if i == 0:
                    token = START_TOKEN
                else:
                    token = tokens[i - 1]
                # Skip tokens predicted as $KEEP.
                if idxs[i] == noop_index:
                    continue

                sugg_token = self.vocab.get_token_from_index(idxs[i],
                                                             namespace='labels')
                action = self.get_token_action(token, i, probabilities[i],
                                               sugg_token)
                if not action:
                    continue

                edits.append(action)
            all_results.append(get_target_sent_by_edits(tokens, edits))
        return all_results

    def handle_batch(self, full_batch):
        """
        Handle batch of requests.
        """
        final_batch = full_batch[:]
        batch_size = len(full_batch)
        prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
        short_ids = [i for i in range(len(full_batch))
                     if len(full_batch[i]) < self.min_len]
        pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
        total_updates = 0

        for n_iter in range(self.iterations):
            orig_batch = [final_batch[i] for i in pred_ids]

            sequences = self.preprocess(orig_batch)

            if not sequences:
                break
            probabilities, idxs, error_probs = self.predict(sequences)

            pred_batch = self.postprocess_batch(orig_batch, probabilities,
                                                idxs, error_probs)
            if self.log:
                print(f"Iteration {n_iter + 1}. Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.")

            final_batch, pred_ids, cnt = \
                self.update_final_batch(final_batch, pred_ids, pred_batch,
                                        prev_preds_dict)
            total_updates += cnt

            if not pred_ids:
                break

        return final_batch, total_updates
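

# Minimal usage sketch (not part of the library): shows how GecBERTModel is
# typically instantiated and queried on pre-tokenized sentences. The vocabulary
# directory and checkpoint path below are placeholders for files you provide.
if __name__ == "__main__":
    model = GecBERTModel(
        vocab_path="data/output_vocabulary",          # hypothetical vocab directory
        model_paths=["models/roberta_1_gector.th"],   # hypothetical checkpoint
        model_name="roberta",
        special_tokens_fix=1,
        is_ensemble=False,
        iterations=3,
    )
    # handle_batch expects a list of token lists and returns corrected token
    # lists plus the number of edits applied across all iterations.
    batch = [["he", "go", "to", "school", "yesterday"]]
    corrected, n_updates = model.handle_batch(batch)
    print(" ".join(corrected[0]), f"({n_updates} updates)")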