| """Wrapper of Seq2Labels model. Fixes errors based on model predictions""" |
| from collections import defaultdict |
| from difflib import SequenceMatcher |
| import logging |
| import re |
| from time import time |
| from typing import List, Union |
| import warnings |
|
|
| import torch |
| from transformers import AutoTokenizer |
| from modeling_seq2labels import Seq2LabelsModel |
| from vocabulary import Vocabulary |
| from utils import PAD, UNK, START_TOKEN, get_target_sent_by_edits |
|
|
# Silence werkzeug's per-request access logs.
logging.getLogger("werkzeug").setLevel(logging.ERROR)
# Use the module name (not __file__) so records follow the package logger hierarchy.
logger = logging.getLogger(__name__)
|
|
|
|
| class GecBERTModel(torch.nn.Module): |
| def __init__( |
| self, |
| vocab_path=None, |
| model_paths=None, |
| weights=None, |
| device=None, |
| max_len=64, |
| min_len=3, |
| lowercase_tokens=False, |
| log=False, |
| iterations=3, |
| min_error_probability=0.0, |
| confidence=0, |
| resolve_cycles=False, |
| split_chunk=False, |
| chunk_size=48, |
| overlap_size=12, |
| min_words_cut=6, |
| punc_dict={':', ".", ",", "?"}, |
| ): |
| r""" |
| Args: |
| vocab_path (`str`): |
| Path to vocabulary directory. |
| model_paths (`List[str]`): |
| List of model paths. |
| weights (`int`, *Optional*, defaults to None): |
| Weights of each model. Only relevant if `is_ensemble is True`. |
| device (`int`, *Optional*, defaults to None): |
| Device to load model. If not set, device will be automatically choose. |
| max_len (`int`, defaults to 64): |
| Max sentence length to be processed (all longer will be truncated). |
| min_len (`int`, defaults to 3): |
| Min sentence length to be processed (all shorted will be returned w/o changes). |
| lowercase_tokens (`bool`, defaults to False): |
| Whether to lowercase tokens. |
| log (`bool`, defaults to False): |
| Whether to enable logging. |
| iterations (`int`, defaults to 3): |
| Max iterations to run during inference. |
| special_tokens_fix (`bool`, defaults to True): |
| Whether to fix problem with [CLS], [SEP] tokens tokenization. |
| min_error_probability (`float`, defaults to `0.0`): |
| Minimum probability for each action to apply. |
| confidence (`float`, defaults to `0.0`): |
| How many probability to add to $KEEP token. |
| split_chunk (`bool`, defaults to False): |
| Whether to split long sentences to multiple segments of `chunk_size`. |
| !Warning: if `chunk_size > max_len`, each segment will be truncate to `max_len`. |
| chunk_size (`int`, defaults to 48): |
| Length of each segment (in words). Only relevant if `split_chunk is True`. |
| overlap_size (`int`, defaults to 12): |
| Overlap size (in words) between two consecutive segments. Only relevant if `split_chunk is True`. |
| min_words_cut (`int`, defaults to 6): |
| Minimun number of words to be cut while merging two consecutive segments. |
| Only relevant if `split_chunk is True`. |
| punc_dict (List[str], defaults to `{':', ".", ",", "?"}`): |
| List of punctuations. |
| """ |
| super().__init__() |
| if isinstance(model_paths, str): |
| model_paths = [model_paths] |
| self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths) |
| self.device = ( |
| torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device) |
| ) |
| self.max_len = max_len |
| self.min_len = min_len |
| self.lowercase_tokens = lowercase_tokens |
| self.min_error_probability = min_error_probability |
| self.vocab = Vocabulary.from_files(vocab_path) |
| self.incorr_index = self.vocab.get_token_index("INCORRECT", "d_tags") |
| self.log = log |
| self.iterations = iterations |
| self.confidence = confidence |
| self.resolve_cycles = resolve_cycles |
|
|
| assert ( |
| chunk_size > 0 and chunk_size // 2 >= overlap_size |
| ), "Chunk merging required overlap size must be smaller than half of chunk size" |
| self.split_chunk = split_chunk |
| self.chunk_size = chunk_size |
| self.overlap_size = overlap_size |
| self.min_words_cut = min_words_cut |
| self.stride = chunk_size - overlap_size |
| self.punc_dict = punc_dict |
| self.punc_str = '[' + ''.join([f'\{x}' for x in punc_dict]) + ']' |
| |
|
|
| self.indexers = [] |
| self.models = [] |
| for model_path in model_paths: |
| model = Seq2LabelsModel.from_pretrained(model_path) |
| config = model.config |
| model_name = config.pretrained_name_or_path |
| special_tokens_fix = config.special_tokens_fix |
| self.indexers.append(self._get_indexer(model_name, special_tokens_fix)) |
| model.eval().to(self.device) |
| self.models.append(model) |
|
|
| def _get_indexer(self, weights_name, special_tokens_fix): |
| tokenizer = AutoTokenizer.from_pretrained( |
| weights_name, do_basic_tokenize=False, do_lower_case=self.lowercase_tokens, model_max_length=1024 |
| ) |
| |
| if hasattr(tokenizer, 'encoder'): |
| tokenizer.vocab = tokenizer.encoder |
| if hasattr(tokenizer, 'sp_model'): |
| tokenizer.vocab = defaultdict(lambda: 1) |
| for i in range(tokenizer.sp_model.get_piece_size()): |
| tokenizer.vocab[tokenizer.sp_model.id_to_piece(i)] = i |
|
|
| if special_tokens_fix: |
| tokenizer.add_tokens([START_TOKEN]) |
| tokenizer.vocab[START_TOKEN] = len(tokenizer) - 1 |
| return tokenizer |
| |
| def forward(self, text: Union[str, List[str], List[List[str]]], is_split_into_words=False): |
| |
| def _is_valid_text_input(t): |
| if isinstance(t, str): |
| |
| return True |
| elif isinstance(t, (list, tuple)): |
| |
| if len(t) == 0: |
| |
| return True |
| elif isinstance(t[0], str): |
| |
| return True |
| elif isinstance(t[0], (list, tuple)): |
| |
| return len(t[0]) == 0 or isinstance(t[0][0], str) |
| else: |
| return False |
| else: |
| return False |
|
|
| if not _is_valid_text_input(text): |
| raise ValueError( |
| "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) " |
| "or `List[List[str]]` (batch of pretokenized examples)." |
| ) |
| |
| if is_split_into_words: |
| is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) |
| else: |
| is_batched = isinstance(text, (list, tuple)) |
| if is_batched: |
| text = [x.split() for x in text] |
| else: |
| text = text.split() |
| |
| if not is_batched: |
| text = [text] |
| |
| return self.handle_batch(text) |
|
|
| def split_chunks(self, batch): |
| |
| result = [] |
| indices = [] |
| for tokens in batch: |
| start = len(result) |
| num_token = len(tokens) |
| if num_token <= self.chunk_size: |
| result.append(tokens) |
| elif num_token > self.chunk_size and num_token < (self.chunk_size * 2 - self.overlap_size): |
| split_idx = (num_token + self.overlap_size + 1) // 2 |
| result.append(tokens[:split_idx]) |
| result.append(tokens[split_idx - self.overlap_size :]) |
| else: |
| for i in range(0, num_token - self.overlap_size, self.stride): |
| result.append(tokens[i : i + self.chunk_size]) |
|
|
| indices.append((start, len(result))) |
|
|
| return result, indices |
|
|
| def check_alnum(self, s): |
| if len(s) < 2: |
| return False |
| return not (s.isalpha() or s.isdigit()) |
|
|
    def apply_chunk_merging(self, tokens, next_tokens):
        """Stitch ``next_tokens`` onto ``tokens`` by aligning their overlap.

        The last ``overlap_size`` non-punctuation words of ``tokens`` are
        matched (case-insensitively) against the first ``overlap_size``
        non-punctuation words of ``next_tokens`` via SequenceMatcher, and a
        cut point inside a matching run is chosen so that roughly
        ``min_words_cut`` overlapping words are taken from the next chunk.

        NOTE(review): if no "equal" opcode satisfies the break conditions,
        ``tail_idx``/``head_idx`` stay unbound and the final concatenation
        raises NameError — ``merge_chunks`` deliberately catches that case.
        """
        # First chunk of a sentence: nothing to merge against.
        if not tokens:
            return next_tokens

        source_token_idx = []
        target_token_idx = []
        source_tokens = []
        target_tokens = []
        # How many overlap words to keep from the *previous* chunk.
        num_keep = self.overlap_size - self.min_words_cut
        i = 0
        # Collect the last `overlap_size` non-punctuation words of `tokens`
        # (negative indices; insert(0, ...) preserves left-to-right order).
        while len(source_token_idx) < self.overlap_size and -i < len(tokens):
            i -= 1
            if tokens[i] not in self.punc_dict:
                source_token_idx.insert(0, i)
                source_tokens.insert(0, tokens[i].lower())

        i = 0
        # Collect the first `overlap_size` non-punctuation words of `next_tokens`.
        while len(target_token_idx) < self.overlap_size and i < len(next_tokens):
            if next_tokens[i] not in self.punc_dict:
                target_token_idx.append(i)
                target_tokens.append(next_tokens[i].lower())
            i += 1

        matcher = SequenceMatcher(None, source_tokens, target_tokens)
        diffs = list(matcher.get_opcodes())

        for diff in diffs:
            tag, i1, i2, j1, j2 = diff
            if tag == "equal":
                if i1 >= num_keep:
                    # Matching run starts at/after the keep boundary: cut here.
                    tail_idx = source_token_idx[i1]
                    head_idx = target_token_idx[j1]
                    break
                elif i2 > num_keep:
                    # Matching run straddles the boundary: cut at the boundary.
                    tail_idx = source_token_idx[num_keep]
                    head_idx = target_token_idx[j2 - i2 + num_keep]
                    break
            elif tag == "delete" and i1 == 0:
                # Leading source words absent from the next chunk: keep more.
                num_keep += i2 // 2

        tokens = tokens[:tail_idx] + next_tokens[head_idx:]
        return tokens
|
|
| def merge_chunks(self, batch): |
| result = [] |
| if len(batch) == 1 or self.overlap_size == 0: |
| for sub_tokens in batch: |
| result.extend(sub_tokens) |
| else: |
| for _, sub_tokens in enumerate(batch): |
| try: |
| result = self.apply_chunk_merging(result, sub_tokens) |
| except Exception as e: |
| print(e) |
|
|
| result = " ".join(result) |
| return result |
|
|
| def predict(self, batches): |
| t11 = time() |
| predictions = [] |
| for batch, model in zip(batches, self.models): |
| batch = batch.to(self.device) |
| with torch.no_grad(): |
| prediction = model.forward(**batch) |
| predictions.append(prediction) |
|
|
| preds, idx, error_probs = self._convert(predictions) |
| t55 = time() |
| if self.log: |
| print(f"Inference time {t55 - t11}") |
| return preds, idx, error_probs |
|
|
| def get_token_action(self, token, index, prob, sugg_token): |
| """Get lost of suggested actions for token.""" |
| |
| if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']: |
| return None |
|
|
| if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE': |
| start_pos = index |
| end_pos = index + 1 |
| elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"): |
| start_pos = index + 1 |
| end_pos = index + 1 |
|
|
| if sugg_token == "$DELETE": |
| sugg_token_clear = "" |
| elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"): |
| sugg_token_clear = sugg_token[:] |
| else: |
| sugg_token_clear = sugg_token[sugg_token.index('_') + 1 :] |
|
|
| return start_pos - 1, end_pos - 1, sugg_token_clear, prob |
|
|
| def preprocess(self, token_batch): |
| seq_lens = [len(sequence) for sequence in token_batch if sequence] |
| if not seq_lens: |
| return [] |
| max_len = min(max(seq_lens), self.max_len) |
| batches = [] |
| for indexer in self.indexers: |
| token_batch = [[START_TOKEN] + sequence[:max_len] for sequence in token_batch] |
| batch = indexer( |
| token_batch, |
| return_tensors="pt", |
| padding=True, |
| is_split_into_words=True, |
| truncation=True, |
| add_special_tokens=False, |
| ) |
| offset_batch = [] |
| for i in range(len(token_batch)): |
| word_ids = batch.word_ids(batch_index=i) |
| offsets = [0] |
| for i in range(1, len(word_ids)): |
| if word_ids[i] != word_ids[i - 1]: |
| offsets.append(i) |
| offset_batch.append(torch.LongTensor(offsets)) |
|
|
| batch["input_offsets"] = torch.nn.utils.rnn.pad_sequence( |
| offset_batch, batch_first=True, padding_value=0 |
| ).to(torch.long) |
|
|
| batches.append(batch) |
|
|
| return batches |
|
|
| def _convert(self, data): |
| all_class_probs = torch.zeros_like(data[0]['logits']) |
| error_probs = torch.zeros_like(data[0]['max_error_probability']) |
| for output, weight in zip(data, self.model_weights): |
| class_probabilities_labels = torch.softmax(output['logits'], dim=-1) |
| all_class_probs += weight * class_probabilities_labels / sum(self.model_weights) |
| class_probabilities_d = torch.softmax(output['detect_logits'], dim=-1) |
| error_probs_d = class_probabilities_d[:, :, self.incorr_index] |
| incorr_prob = torch.max(error_probs_d, dim=-1)[0] |
| error_probs += weight * incorr_prob / sum(self.model_weights) |
|
|
| max_vals = torch.max(all_class_probs, dim=-1) |
| probs = max_vals[0].tolist() |
| idx = max_vals[1].tolist() |
| return probs, idx, error_probs.tolist() |
|
|
| def update_final_batch(self, final_batch, pred_ids, pred_batch, prev_preds_dict): |
| new_pred_ids = [] |
| total_updated = 0 |
| for i, orig_id in enumerate(pred_ids): |
| orig = final_batch[orig_id] |
| pred = pred_batch[i] |
| prev_preds = prev_preds_dict[orig_id] |
| if orig != pred and pred not in prev_preds: |
| final_batch[orig_id] = pred |
| new_pred_ids.append(orig_id) |
| prev_preds_dict[orig_id].append(pred) |
| total_updated += 1 |
| elif orig != pred and pred in prev_preds: |
| |
| final_batch[orig_id] = pred |
| total_updated += 1 |
| else: |
| continue |
| return final_batch, new_pred_ids, total_updated |
|
|
    def postprocess_batch(self, batch, all_probabilities, all_idxs, error_probs):
        """Convert per-position label predictions into edited token sequences.

        Args:
            batch: list of token lists (without $START).
            all_probabilities: per-sentence, per-position max label probability
                (position 0 corresponds to $START).
            all_idxs: per-sentence, per-position predicted label indices.
            error_probs: per-sentence error probability.

        Returns:
            List of token lists; unchanged sentences are returned as-is.
        """
        all_results = []
        noop_index = self.vocab.get_token_index("$KEEP", "labels")
        for tokens, probabilities, idxs, error_prob in zip(batch, all_probabilities, all_idxs, error_probs):
            length = min(len(tokens), self.max_len)
            edits = []

            # All-zero predictions mean no edits anywhere (label index 0 is
            # presumably the $KEEP/padding slot — verify against the vocab).
            if max(idxs) == 0:
                all_results.append(tokens)
                continue

            # Sentence-level gate: skip sentences the detector deems clean.
            if error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            for i in range(length + 1):
                # Position 0 is the $START placeholder; i > 0 maps to tokens[i-1].
                if i == 0:
                    token = START_TOKEN
                else:
                    token = tokens[i - 1]

                if idxs[i] == noop_index:
                    continue

                sugg_token = self.vocab.get_token_from_index(idxs[i], namespace='labels')
                action = self.get_token_action(token, i, probabilities[i], sugg_token)
                if not action:
                    continue

                edits.append(action)
            # Apply the accumulated edit spans to the original tokens.
            all_results.append(get_target_sent_by_edits(tokens, edits))
        return all_results
|
|
    def handle_batch(self, full_batch, merge_punc=True):
        """
        Handle batch of requests.

        Iteratively corrects `full_batch` (a list of token lists): up to
        `self.iterations` rounds of preprocess -> predict -> postprocess,
        re-running only the sentences whose prediction changed last round.

        Args:
            full_batch: batch of pretokenized sentences (lists of words).
            merge_punc: whether to glue punctuation back onto the preceding
                word in the detokenized output.

        Returns:
            List of corrected sentences as strings.
        """
        if self.split_chunk:
            # Long sentences become several overlapping chunks; `indices`
            # records each original sentence's [start, end) chunk span.
            full_batch, indices = self.split_chunks(full_batch)
        else:
            indices = None
        final_batch = full_batch[:]
        batch_size = len(full_batch)
        # Every prediction seen per sentence, to detect oscillation.
        prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
        # Sentences shorter than min_len are returned unchanged.
        short_ids = [i for i in range(len(full_batch)) if len(full_batch[i]) < self.min_len]
        pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
        total_updates = 0

        for n_iter in range(self.iterations):
            orig_batch = [final_batch[i] for i in pred_ids]

            sequences = self.preprocess(orig_batch)

            if not sequences:
                break
            probabilities, idxs, error_probs = self.predict(sequences)

            pred_batch = self.postprocess_batch(orig_batch, probabilities, idxs, error_probs)
            if self.log:
                print(f"Iteration {n_iter + 1}. Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.")

            final_batch, pred_ids, cnt = self.update_final_batch(final_batch, pred_ids, pred_batch, prev_preds_dict)
            total_updates += cnt

            # Stop early once no sentence produced a new prediction.
            if not pred_ids:
                break
        if self.split_chunk:
            # Reassemble each original sentence from its chunk span.
            final_batch = [self.merge_chunks(final_batch[start:end]) for (start, end) in indices]
        else:
            final_batch = [" ".join(x) for x in final_batch]
        if merge_punc:
            # Remove the space before punctuation, e.g. "word ." -> "word.".
            final_batch = [re.sub(r'\s+(%s)' % self.punc_str, r'\1', x) for x in final_batch]

        return final_batch