| import torch | |
def normalize_abbreviations(text):
    """Glue space-separated English clitics back onto the preceding word.

    Handles ``n't``, ``'ll``, ``'re``, ``'ve``, ``'m``, ``'s`` and ``'d`` in
    their all-lower-case and all-upper-case spellings. A clitic is only
    rewritten where it has a single space on both sides; the space in front
    of it is removed and the space after it is kept.
    """
    clitics = ("n't", "'ll", "'re", "'ve", "'m", "'s", "'d")
    for clitic in clitics:
        # Lower-case form first, then upper-case, for each clitic in turn.
        for form in (clitic, clitic.upper()):
            text = text.replace(f" {form} ", f"{form} ")
    return text
def fix_quotes(text, quote_symbol='"'):
    """Tighten the spacing around paired quotation marks.

    Quotes are visited left to right and alternately treated as opening and
    closing. An opening quote gets a space in front of it (unless it starts
    the text) and loses a space directly after it. A closing quote loses a
    space directly in front of it and gets a space after it when an
    alphanumeric character follows. A quote that has a non-space character
    on both sides is left alone and not counted.

    The text is returned unchanged when the number of quotes that touch a
    space is zero or odd, or when two quotes are adjacent or separated by a
    single space.
    """
    mark = quote_symbol
    touching_space = (
        text.count(f" {mark}") + text.count(f"{mark} ") - text.count(f" {mark} ")
    )
    if touching_space == 0 or touching_space % 2 == 1:
        return text
    if f"{mark}{mark}" in text or f"{mark} {mark}" in text:
        return text

    pos = 0
    quotes_seen = 0
    while pos < len(text):
        if text[pos] != mark:
            pos += 1
            continue

        glued_left = pos - 1 >= 0 and text[pos - 1] != ' '
        glued_right = pos + 1 < len(text) and text[pos + 1] != ' '
        if glued_left and glued_right:
            # Quote embedded in a word (e.g. an apostrophe): not a delimiter.
            pos += 1
            continue

        is_opening = quotes_seen % 2 == 0
        if is_opening:
            if pos > 0 and text[pos - 1] != ' ':
                text = text[:pos] + ' ' + text[pos:]
                pos += 1  # the quote moved one place to the right
            if pos + 1 < len(text) and text[pos + 1] == ' ':
                text = text[:pos + 1] + text[pos + 2:]
        else:
            if pos > 0 and text[pos - 1] == ' ':
                text = text[:pos - 1] + text[pos:]
                pos -= 1  # the quote moved one place to the left
            if pos + 1 < len(text) and text[pos + 1].isalnum():
                text = text[:pos + 1] + ' ' + text[pos + 1:]

        quotes_seen += 1
        pos += 1

    return text
def _respace_punctuation(text):
    """Insert or remove single spaces next to punctuation marks in ``text``.

    NOTE(review): the typographic literals used here (opening/closing double
    quotes, horizontal ellipsis, pound and euro signs) were restored from
    mis-encoded text. The ellipsis and currency signs are unambiguous; the
    two quotation marks are the most plausible reading of an ambiguous byte
    sequence -- confirm against the upstream source.
    """
    openers = ('“', '(', '[', '{')

    # Scan right to left: every edit touches index ``i`` or ``i + 1`` only, so
    # the positions that are still to be visited never shift.
    for i in range(len(text) - 2, -1, -1):
        cur, nxt = text[i], text[i + 1]

        if cur == '.' and (nxt.isupper() or nxt in openers):
            text = text[:i + 1] + ' ' + text[i + 1:]
        elif cur in ('?', '!', '…', '”') and (nxt.isalnum() or nxt in openers):
            text = text[:i + 1] + ' ' + text[i + 1:]
        elif i > 2 and cur == '.' and text[i - 1] == '.' and text[i - 2] == '.' and nxt != ' ':
            # Three consecutive dots directly followed by a non-space.
            text = text[:i + 1] + ' ' + text[i + 1:]
        elif cur == ',' and (nxt.isalpha() or nxt in openers):
            text = text[:i + 1] + ' ' + text[i + 1:]
        elif cur in (';', ')', ']', '}', '%') and (nxt.isalnum() or nxt in openers):
            text = text[:i + 1] + ' ' + text[i + 1:]
        elif cur == ':' and (
            nxt in openers
            or (nxt.isalnum() and (not nxt.isnumeric() or i - 1 < 0 or not text[i - 1].isnumeric()))
        ):
            # A colon between two numeric characters stays glued.
            text = text[:i + 1] + ' ' + text[i + 1:]
        elif cur in ('(', '[', '{') and nxt == ' ':
            # Drop the space right after an opening bracket.
            text = text[:i + 1] + text[i + 2:]
        elif cur == ' ' and nxt in ('.', ';', ':', '?', '!', '…', ',', '”', ')', ']'):
            # Drop the space in front of closing punctuation.
            text = text[:i] + text[i + 1:]
        elif i > 0 and cur == ' ' and text[i - 1] in ('$', '£', '€') and nxt.isnumeric():
            # Currency sign followed by an amount.
            text = text[:i] + text[i + 1:]
        elif i > 0 and cur == ' ' and text[i - 1].isnumeric() and nxt == '%':
            # Number followed by a percent sign.
            text = text[:i] + text[i + 1:]

    return text


def _token_spans(text, tokens):
    """Return the ``(start, end)`` character range (end exclusive) of each token in ``text``."""
    spans = []
    word_offset, char_offset = 0, 0
    for i, ch in enumerate(text):
        if ch == ' ':
            # A space is consumed from the current token only when the token
            # itself has a space at this position; otherwise it is spacing.
            if tokens[word_offset][char_offset] == ' ':
                char_offset += 1
            continue

        assert ch == tokens[word_offset][char_offset], f"{text}\n{' '.join(tokens)}\n{tokens[word_offset]}\n{char_offset} {ch}"

        if char_offset == 0:
            start = i

        if char_offset == len(tokens[word_offset]) - 1:
            # Last character of the token: close its span and move on.
            spans.append((start, i + 1))
            word_offset += 1
            char_offset = 0
        else:
            char_offset += 1

    return spans


def detokenize(tokens, compact_dashes=False):
    """Join word tokens into naturally spaced text and locate each token in it.

    The tokens are joined with single spaces, clitics are re-attached with
    ``normalize_abbreviations``, spacing around punctuation is adjusted, and
    both ``"`` and ``'`` quote pairs are tightened with ``fix_quotes``.

    Args:
        tokens: sequence of token strings.
        compact_dashes: when true, every ``' - '`` in the joined text is
            collapsed to ``'-'`` before the punctuation pass.

    Returns:
        A ``(text, spans)`` pair. ``spans`` holds one ``(start, end)``
        character range per token, ``end`` exclusive, indexing into ``text``.

    Raises:
        AssertionError: if a non-space character of ``text`` does not match
            the token character it is aligned with.
    """
    text = ' '.join(tokens)
    text = normalize_abbreviations(text)

    if compact_dashes:
        text = text.replace(' - ', '-')

    text = _respace_punctuation(text)

    text = fix_quotes(text, '"')
    text = fix_quotes(text, "'")

    return text, _token_spans(text, tokens)
def calculate_spans(original_spans, encoding_offsets):
    """Group subword positions by the word span they fall into.

    Args:
        original_spans: per-word ``(start, end)`` character ranges.
        encoding_offsets: per-subword ``(start, end)`` character offsets;
            only the end offset is read.

    Returns:
        One list per word containing the 1-based positions (within
        ``encoding_offsets``) of the subwords assigned to that word. A
        subword that reaches past the end of a word is also assigned to each
        following word it overlaps.
    """
    word_count = len(original_spans)
    groups = [[] for _ in original_spans]
    word = 0

    for position, (_, subword_end) in enumerate(encoding_offsets, start=1):
        groups[word].append(position)

        # Advance past every word that this subword finishes.
        while original_spans[word][1] <= subword_end:
            word += 1
            if word == word_count:
                # All words are covered; remaining subwords are ignored.
                return groups
            if subword_end > original_spans[word][0]:
                groups[word].append(position)

    return groups
def subtokenize(tokens, tokenizer, compact_dashes=False):
    """Tokenize words into subwords and build a subword-to-word alignment mask.

    Args:
        tokens: sequence of word strings, passed to ``detokenize``.
        tokenizer: callable invoked as ``tokenizer(text,
            return_offsets_mapping=True)``; its result is indexed with
            ``"offset_mapping"`` and ``"input_ids"``.
        compact_dashes: forwarded to ``detokenize``.

    Returns:
        A ``(subwords, subword_mask)`` pair: the ``"input_ids"`` of the
        encoding and a boolean tensor with one row per subword id and one
        column per word.
    """
    text, word_spans = detokenize(tokens, compact_dashes=compact_dashes)
    encoding = tokenizer(text, return_offsets_mapping=True)

    # The first and last offsets are dropped -- presumably the special tokens
    # the tokenizer adds around the text; verify for the tokenizer in use.
    inner_offsets = encoding["offset_mapping"][1:-1]
    subwords_per_word = calculate_spans(word_spans, inner_offsets)

    subwords = encoding["input_ids"]
    subword_mask = torch.zeros(len(subwords), len(subwords_per_word), dtype=torch.bool)

    # NOTE(review): calculate_spans already returns positions shifted by one
    # (it appends ``i + 1``); the additional ``+ 1`` here shifts the rows once
    # more. Confirm this double offset is what the consumer of the mask expects.
    for word_id, subword_ids in enumerate(subwords_per_word):
        for subword_id in subword_ids:
            subword_mask[subword_id + 1, word_id] = True

    return subwords, subword_mask