"""Tokenization and segmentation utilities for streaming coreference documents."""
| import torch | |
class DocumentState:
    """Accumulates per-document tokenization state and emits the final dict."""

    def __init__(self):
        # Boolean flags marking sentence / word boundaries per subtoken position.
        self.sentence_end = []
        self.token_end = []
        # Raw and processed token streams.
        self.orig_tokens = []
        self.tokens = []
        self.subtokens = []
        # Segment-level views populated by split_into_segments().
        self.segments = []
        self.subtoken_map = []
        self.segment_subtoken_map = []
        self.sentence_map = []
        # Tensors and lengths filled in by post_tokenization_processing().
        self.tensorized_sent = []
        self.sent_len_list = []

    def finalize(self):
        """Return the document as a plain dict, sanity-checking the subtoken map."""
        flat_map = flatten(self.segment_subtoken_map)
        total_subtokens = len(flatten(self.segments))
        assert total_subtokens == len(flat_map), (total_subtokens, len(flat_map))
        sentence_map_tensor = torch.tensor(
            get_sentence_map(self.segments, self.sentence_end)
        )
        return {
            "orig_tokens": self.orig_tokens,
            "sentences": self.segments,
            "sent_len_list": self.sent_len_list,
            "tensorized_sent": self.tensorized_sent,
            "sentence_map": sentence_map_tensor,
            "subtoken_map": flat_map,
        }
def get_sentence_map(segments, sentence_end):
    """Map every subtoken position to the index of the sentence containing it.

    ``sentence_end[i]`` is True iff the i-th subtoken (counted across all
    segments) closes a sentence; the sentence counter advances immediately
    after each such position.
    """
    assert len(sentence_end) == sum(len(seg) for seg in segments)
    sentence_idx = 0
    flat_pos = 0
    mapping = []
    for segment in segments:
        for _ in segment:
            mapping.append(sentence_idx)
            sentence_idx += int(sentence_end[flat_pos])
            flat_pos += 1
    return mapping
def split_into_segments(document_state, max_segment_len, constraints1, constraints2):
    """Greedily chop the subtoken stream into segments of bounded length.

    Each segment must end where ``constraints1`` is True; if no such position
    exists in the window, ``constraints2`` is tried as a fallback. The
    ``- 1 - 2`` budget reserves room for the special wrapper tokens added
    later. Raises if neither constraint can be satisfied.
    """
    n_subtokens = len(document_state.subtokens)
    start = 0
    while start < n_subtokens:
        # Largest candidate end index for this window.
        limit = min(start + max_segment_len - 1 - 2, n_subtokens - 1)
        end = limit
        while end >= start and not constraints1[end]:
            end -= 1
        if end < start:
            # Primary constraint unsatisfiable in this window; retry with fallback.
            end = limit
            while end >= start and not constraints2[end]:
                end -= 1
            if end < start:
                raise Exception("Can't find valid segment")
        document_state.segments.append(document_state.subtokens[start : end + 1])
        document_state.segment_subtoken_map.append(
            document_state.subtoken_map[start : end + 1]
        )
        if hasattr(document_state, "info"):
            document_state.segment_info.append(document_state.info[start : end + 1])
        start = end + 1
def flatten(l):
    """Concatenate a list of lists into a single flat list."""
    out = []
    for chunk in l:
        out.extend(chunk)
    return out
def get_tokenized_doc(doc, subword_tokenizer):
    """Subword-tokenize a basic-tokenized document.

    Args:
        doc: Iterable of sentences, each an iterable of word strings.
        subword_tokenizer: HuggingFace-style tokenizer exposing ``tokenize``
            and ``convert_tokens_to_ids``.

    Returns:
        A populated DocumentState with subtoken ids, word/sentence boundary
        flags, and a subtoken -> word-index map.
    """
    document_state = DocumentState()
    word_idx = -1
    for sentence in doc:
        for word in sentence:
            document_state.orig_tokens.append(word)
            # Leading space so BPE-style tokenizers treat the word as
            # space-prefixed (matches add_prefix_space behavior).
            subtokens = subword_tokenizer.convert_tokens_to_ids(
                subword_tokenizer.tokenize(" " + word)
            )
            document_state.tokens.append(word)
            # Only the last subtoken of a word closes the word.
            document_state.token_end += ([False] * (len(subtokens) - 1)) + [True]
            word_idx += 1
            for subtoken in subtokens:
                document_state.subtokens.append(subtoken)
                document_state.sentence_end.append(False)
                document_state.subtoken_map.append(word_idx)
        # Mark the sentence boundary on the last emitted subtoken. Guard
        # against an empty leading sentence, which previously raised
        # IndexError on the empty sentence_end list.
        if document_state.sentence_end:
            document_state.sentence_end[-1] = True
    return document_state
def basic_tokenize_doc(doc_str, basic_tokenizer):
    """Split raw text into sentences of word strings via a spaCy-like tokenizer.

    ``basic_tokenizer(doc_str)`` must yield an object with a ``.sents``
    iterable of token sequences.
    """
    return [
        [str(token) for token in sentence]
        for sentence in basic_tokenizer(doc_str).sents
    ]
def tokenize_and_segment_doc(
    basic_tokenized_doc, subword_tokenizer, max_segment_len=4096
):
    """Subword-tokenize a pre-sentence-split document and segment it.

    Convenience wrapper: get_tokenized_doc followed by
    post_tokenization_processing.
    """
    state: DocumentState = get_tokenized_doc(basic_tokenized_doc, subword_tokenizer)
    return post_tokenization_processing(
        state, subword_tokenizer, max_segment_len=max_segment_len
    )
def post_tokenization_processing(
    document_state: DocumentState, subword_tokenizer, max_segment_len=4096
):
    """Segment the subtoken stream, tensorize each segment, and finalize.

    Segmentation prefers sentence boundaries and falls back to word
    boundaries. Each segment is wrapped with the tokenizer's CLS/SEP ids and
    turned into a (1, seg_len + 2) tensor — streaming coreference processes
    one window at a time, so no padding is required.
    """
    split_into_segments(
        document_state,
        max_segment_len,
        document_state.sentence_end,
        document_state.token_end,
    )
    document_state.sent_len_list = [
        len(segment) for segment in document_state.segments
    ]
    document_state.segments_indices = document_state.segments
    cls_id = subword_tokenizer.cls_token_id
    sep_id = subword_tokenizer.sep_token_id
    document_state.tensorized_sent = [
        torch.unsqueeze(torch.tensor([cls_id] + segment + [sep_id]), dim=0)
        for segment in document_state.segments
    ]
    return document_state.finalize()
if __name__ == "__main__":
    from transformers import LongformerTokenizerFast

    tokenizer = LongformerTokenizerFast.from_pretrained(
        "allenai/longformer-large-4096",
        add_prefix_space=True,
        clean_up_tokenization_spaces=True,
    )
    sample_doc_str = "My father’s eyes had closed upon the light of this world six months, when Ishmael opened on it."
    # BUG FIX: get_tokenized_doc expects a list of sentences (each a list of
    # words), not a raw string — a string would be iterated character by
    # character. Whitespace-split the sample into a one-sentence document.
    sample_doc = [sample_doc_str.split()]
    print(get_tokenized_doc(sample_doc, tokenizer))