from collections import defaultdict
from functools import lru_cache
class DynamicDepth:
    """Finds the relative depth of every word in a subphrase.

    Uses a cache plus dynamic programming (memoized recursion) over the head
    index of each word, given the head word for every word.
    """
def get_parse_depths(self, heads, start, end):
"""Return the relative depth for every word
Args:
heads (list): List where each entry is the index of that entry's head word in the dependency parse
start (int): starting index of the heads for the subphrase
end (int): ending index of the heads for the subphrase
Returns:
list: Relative depth in the dependency parse for every word
"""
        self.heads = heads[start:end]
        # -100 marks words with no head ('None' headwords, i.e. the root);
        # the explicit None check matters, since 0 is a valid head index
        self.relative_heads = [h - start if h is not None else -100 for h in self.heads]
        depths = [self._get_depth_recursive(h) for h in range(len(self.relative_heads))]
        return depths
    @lru_cache(maxsize=None)
    def _get_depth_recursive(self, index):
        """Get the depth of the word at an index via memoized recursion.

        The cache keys on (self, index), so a fresh DynamicDepth instance
        must be used for each new subphrase.

        Args:
            index (int): Index of the word for which to calculate the relative depth

        Returns:
            int: Relative depth of the word at the index
        """
# if the head for the current index is outside the scope, this index is a relative root
if self.relative_heads[index] >= len(self.relative_heads) or self.relative_heads[index] < 0:
return 0
return self._get_depth_recursive(self.relative_heads[index]) + 1
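
# A small illustration (hypothetical parse; assumes 0-based head indices with
# None for the root, the format find_cconj_head below passes in): for the
# subphrase "the big dog barked loudly" with heads [2, 2, 3, None, 3],
#
#   DynamicDepth().get_parse_depths([2, 2, 3, None, 3], 0, 5)
#
# returns [2, 2, 1, 0, 1]: "barked" is the root at depth 0, "dog" and
# "loudly" attach to it at depth 1, and "the" / "big" sit at depth 2.
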
def find_cconj_head(heads, upos, start, end):
    """
    Finds how far each word is from the head of a span, then uses the first
    CCONJ within a small depth limit of the head as the new head.
    If no such CCONJ is present, returns None.
    """
    # use head information to extract parse depth; a fresh DynamicDepth is
    # created per call, since its cache is keyed on the instance
    dynamic_depth = DynamicDepth()
    depth = dynamic_depth.get_parse_depths(heads, start, end)
depth_limit = 2
    # return the first 'CCONJ' token shallower than the depth limit, if one exists
    # unlike the original paper, we expect the parses to use UPOS, hence CCONJ instead of CC
cc_indexes = [i for i in range(end - start) if upos[i+start] == 'CCONJ' and depth[i] < depth_limit]
if cc_indexes:
return cc_indexes[0] + start
return None
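
# For example (hypothetical span, with the conjunction attached to the first
# conjunct as in SD-style parses): for "cats and dogs" with
# heads [None, 0, 0] and upos ['NOUN', 'CCONJ', 'NOUN'],
#
#   find_cconj_head([None, 0, 0], ['NOUN', 'CCONJ', 'NOUN'], 0, 3)
#
# computes depths [0, 1, 1] and returns 1, the index of the CCONJ. In a UD
# parse where "and" instead attaches to the second conjunct ([None, 2, 0]),
# the CCONJ sits at depth 2 and the function returns None.
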
def process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=True, lang=None):
    """
    pipe: a Stanza pipeline, used to dependency parse the raw text
    doc_id: a string naming the document
    part_id: if the document has a particular subpart (can be blank)
    sentences: a list of lists of strings representing the raw text
    coref_spans: a list of lists
      one list per sentence
      each sentence has a list of spans, where each span is (span_index, span_start, span_end)
      the indices are relative to 0 for that particular sentence, and if the span is exactly 1 word long, span_start == span_end
    sentence_speakers: a list of lists of strings representing who said each word; can all be blank (or None) if there are no known speakers
    use_cconj_heads: if True, prefer a CCONJ near the span root as the span's head word
    lang: optional language code recorded in the output
    """
sentence_lens = [len(x) for x in sentences]
    if sentence_speakers is None:
        # no speaker info: one blank placeholder per sentence, which the
        # string branch below expands to one entry per word
        sentence_speakers = [" " for _ in sentences]
if all(isinstance(x, list) for x in sentence_speakers):
speaker = [y for x in sentence_speakers for y in x]
else:
speaker = [y for x, sent_len in zip(sentence_speakers, sentence_lens) for y in [x] * sent_len]
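    # e.g. per-sentence speakers ["spkA", "spkB"] with sentence lengths [2, 1]
    # expand to ["spkA", "spkA", "spkB"]; per-word speaker lists are flattened as-is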
cased_words = [y for x in sentences for y in x]
sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]
# use the trees to get the xpos tags
# alternatively, could translate the pos_tags field,
# but those have numbers, which is annoying
#tree_text = "\n".join(x['parse_tree'] for x in paragraph)
#trees = tree_reader.read_trees(tree_text)
#pos = [x.label for tree in trees for x in tree.yield_preterminals()]
# actually, the downstream code doesn't use pos at all. maybe we can skip?
doc = pipe(sentences)
word_total = 0
heads = []
# TODO: does SD vs UD matter?
deprel = []
for sentence in doc.sentences:
for word in sentence.words:
deprel.append(word.deprel)
            if word.head == 0:
                # head == 0 marks this word as the sentence root
                heads.append("null")
            else:
                # convert the 1-based within-sentence head index to a
                # 0-based document-wide index
                heads.append(word.head - 1 + word_total)
word_total += len(sentence.words)
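    # e.g. for two hypothetical 3-word parses "John sleeps ." and "He snores .",
    # the document-wide heads would be [1, "null", 1, 4, "null", 4]:
    # each head index is shifted by the number of words already seen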
span_clusters = defaultdict(list)
word_clusters = defaultdict(list)
head2span = []
word_total = 0
for sent_idx, (parsed_sentence, ontonotes_words) in enumerate(zip(doc.sentences, sentences)):
sentence_upos = [x.upos for x in parsed_sentence.words]
        # stanza heads are 1-based, with 0 for the root; convert to 0-based with None for the root
        sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words]
for span in coref_spans[sent_idx]:
            # the output span is (start_word, end_word + 1), counting from 0,
            # whereas the OntoNotes coref_span is [start_word, end_word] inclusive
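            # e.g. an OntoNotes one-word span (idx, 3, 3) after a 10-word
            # first sentence becomes the half-open span (13, 14)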
span_start = span[1] + word_total
span_end = span[2] + word_total + 1
candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) if use_cconj_heads else None
            if candidate_head is None:
                # stanza counts heads from 1, with 0 marking the root, whereas
                # OntoNotes counts words from 0, so we subtract 1 from the stanza heads.
                # treat the head of the phrase as the first word that has a head outside the phrase
                for candidate_head in range(span[1], span[2] + 1):
                    if (parsed_sentence.words[candidate_head].head - 1 < span[1] or
                            parsed_sentence.words[candidate_head].head - 1 > span[2]):
                        break
                else:
                    # if no word has a head outside the phrase (circular??)
                    # then just take the first word
                    candidate_head = span[1]
candidate_head += word_total
span_clusters[span[0]].append((span_start, span_end))
word_clusters[span[0]].append(candidate_head)
head2span.append((candidate_head, span_start, span_end))
word_total += len(ontonotes_words)
span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])
word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])
head2span = sorted(head2span)
processed = {
"document_id": doc_id,
"part_id": part_id,
"cased_words": cased_words,
"sent_id": sent_id,
"speaker": speaker,
#"pos": pos,
"deprel": deprel,
"head": heads,
"span_clusters": span_clusters,
"word_clusters": word_clusters,
"head2span": head2span,
}
if lang is not None:
processed["lang"] = lang
return processed
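
# A minimal usage sketch (hypothetical data; assumes a Stanza pipeline built
# with tokenize_pretokenized=True so the gold tokenization is preserved):
#
#   import stanza
#   pipe = stanza.Pipeline("en", processors="tokenize,pos,lemma,depparse",
#                          tokenize_pretokenized=True)
#   sentences = [["John", "saw", "himself"]]
#   coref_spans = [[(0, 0, 0), (0, 2, 2)]]   # "John", "himself" in cluster 0
#   processed = process_document(pipe, "sample_doc", 0, sentences, coref_spans,
#                                sentence_speakers=None, lang="en")
#   # processed["span_clusters"] == [[(0, 1), (2, 3)]]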