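"""Tokenization and segmentation utilities for streaming coreference.

Converts a document (a list of sentences, each a list of words) into
bounded-length subword segments that can be fed to a Longformer-style
encoder one window at a time.
"""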
import torch


class DocumentState:
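    """Accumulates tokenization state for one document: words, subword ids,
    boundary flags, and the segment structure derived from them."""
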
def __init__(self):
        self.sentence_end = []  # per-subtoken flag: True at a sentence-final subtoken
        self.token_end = []  # per-subtoken flag: True at a word-final subtoken
        self.orig_tokens = []  # original words, in document order
        self.tokens = []  # word-level tokens (kept in parallel with orig_tokens)
        self.subtokens = []  # subword token ids for the whole document
        self.segments = []  # subtoken ids split into bounded-length segments
        self.subtoken_map = []  # subtoken index -> word index
        self.segment_subtoken_map = []  # subtoken_map sliced per segment
        self.sentence_map = []  # subtoken index -> sentence index
        self.tensorized_sent = []  # per-segment tensors with special tokens added
        self.sent_len_list = []  # number of subtokens in each segment

    def finalize(self):
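        """Check that the segment maps are consistent and return the document
        as a dict of model-ready fields."""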
subtoken_map = flatten(self.segment_subtoken_map)
num_words = len(flatten(self.segments))
assert num_words == len(subtoken_map), (num_words, len(subtoken_map))
return {
"orig_tokens": self.orig_tokens,
"sentences": self.segments,
"sent_len_list": self.sent_len_list,
"tensorized_sent": self.tensorized_sent,
"sentence_map": torch.tensor(
get_sentence_map(self.segments, self.sentence_end)
),
"subtoken_map": subtoken_map,
}


def get_sentence_map(segments, sentence_end):
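    """Map each subtoken to the index of the sentence containing it.

    Example: segments=[["a", "b"], ["c"]] with sentence_end=[False, True, True]
    yields [0, 0, 1].
    """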
current = 0
sent_map = []
sent_end_idx = 0
assert len(sentence_end) == sum([len(s) for s in segments])
for segment in segments:
for i in range(len(segment)):
sent_map.append(current)
current += int(sentence_end[sent_end_idx])
sent_end_idx += 1
return sent_map


def split_into_segments(document_state, max_segment_len, constraints1, constraints2):
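    """Greedily split the document's subtokens into segments of at most
    max_segment_len - 2 subtokens (two positions are reserved for the special
    tokens added during tensorization). Breaks are preferred where constraints1
    is True (sentence ends) and fall back to constraints2 (word ends).
    """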
current = 0
while current < len(document_state.subtokens):
end = min(current + max_segment_len - 1 - 2, len(document_state.subtokens) - 1)
while end >= current and not constraints1[end]:
end -= 1
if end < current:
end = min(
current + max_segment_len - 1 - 2, len(document_state.subtokens) - 1
)
while end >= current and not constraints2[end]:
end -= 1
if end < current:
raise Exception("Can't find valid segment")
document_state.segments.append(document_state.subtokens[current : end + 1])
subtoken_map = document_state.subtoken_map[current : end + 1]
document_state.segment_subtoken_map.append(subtoken_map)
        # Optional per-subtoken metadata that some callers attach; note that
        # DocumentState above does not initialize info/segment_info itself.
        if hasattr(document_state, "info"):
            info = document_state.info[current : end + 1]
            document_state.segment_info.append(info)
current = end + 1


def flatten(lst):
    """Flatten a list of lists into a single list."""
    return [item for sublist in lst for item in sublist]


def get_tokenized_doc(doc, subword_tokenizer):
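    """Subword-tokenize a basic-tokenized doc (a list of sentences, each a
    list of words), recording word and sentence boundary flags per subtoken."""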
document_state = DocumentState()
word_idx = -1
for sentence in doc:
for word in sentence:
document_state.orig_tokens.append(word)
            # Prepend a space so the byte-level BPE tokenizer produces the
            # space-prefixed (mid-sentence) form of the word
            subtokens = subword_tokenizer.convert_tokens_to_ids(
                subword_tokenizer.tokenize(" " + word)
            )
document_state.tokens.append(word)
document_state.token_end += ([False] * (len(subtokens) - 1)) + [True]
word_idx += 1
            for subtoken in subtokens:
document_state.subtokens.append(subtoken)
document_state.sentence_end.append(False)
document_state.subtoken_map.append(word_idx)
document_state.sentence_end[-1] = True
return document_state


def basic_tokenize_doc(doc_str, basic_tokenizer):
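    """Split a raw document string into a list of sentences, each a list of
    word strings, using a tokenizer with a spaCy-like API (the object returned
    by basic_tokenizer must expose .sents)."""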
doc = []
for sent in basic_tokenizer(doc_str).sents:
wordlist = [str(word) for word in sent]
doc.append(wordlist)
return doc


def tokenize_and_segment_doc(
basic_tokenized_doc, subword_tokenizer, max_segment_len=4096
):
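    """Full pipeline: subword-tokenize a basic-tokenized doc and split it into
    model-ready segments."""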
document_state: DocumentState = get_tokenized_doc(
basic_tokenized_doc, subword_tokenizer
)
document = post_tokenization_processing(
document_state, subword_tokenizer, max_segment_len=max_segment_len
)
return document


def post_tokenization_processing(
document_state: DocumentState, subword_tokenizer, max_segment_len=4096
):
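    """Segment the tokenized document, tensorize each segment with special
    tokens, and return the finalized document dict."""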
split_into_segments(
document_state,
max_segment_len,
document_state.sentence_end,
document_state.token_end,
)
sent_len_list = [len(sent) for sent in document_state.segments]
document_state.sent_len_list = sent_len_list
document_state.segments_indices = document_state.segments
    # Tensorize each segment. Streaming coreference processes one window at a
    # time, so no padding is required.
tensorized_sent = [
torch.unsqueeze(
torch.tensor(
[subword_tokenizer.cls_token_id]
+ sent
+ [subword_tokenizer.sep_token_id]
),
dim=0,
)
for sent in document_state.segments
]
document_state.tensorized_sent = tensorized_sent
return document_state.finalize()


if __name__ == "__main__":
from transformers import LongformerTokenizerFast
tokenizer = LongformerTokenizerFast.from_pretrained(
"allenai/longformer-large-4096",
add_prefix_space=True,
clean_up_tokenization_spaces=True,
)
    # get_tokenized_doc expects a list of sentences, each a list of words,
    # so pre-split the sample string instead of passing the raw string in.
    sample_doc_str = "My father’s eyes had closed upon the light of this world six months, when Ishmael opened on it."
    sample_doc = [sample_doc_str.split()]
    print(tokenize_and_segment_doc(sample_doc, tokenizer))
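
    # End-to-end sketch using basic_tokenize_doc. This assumes spaCy and its
    # en_core_web_sm model are installed; neither is imported above, so the
    # lines below are illustrative rather than part of the tested pipeline.
    # import spacy
    # nlp = spacy.load("en_core_web_sm")
    # basic_doc = basic_tokenize_doc(sample_doc_str, nlp)
    # print(tokenize_and_segment_doc(basic_doc, tokenizer)["sent_len_list"])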