from collections import defaultdict
from functools import lru_cache

class DynamicDepth:
    """
    Uses a cache + dynamic programming to find the relative depth of every word in a subphrase, given the head word index for every word.
    """
    def get_parse_depths(self, heads, start, end):
        """Return the relative depth for every word

        Args:
            heads (list): List where each entry is the index of that entry's head word in the dependency parse
            start (int): starting index of the heads for the subphrase
            end (int): ending index of the heads for the subphrase

        Returns:
            list: Relative depth in the dependency parse for every word
        """
        self.heads = heads[start:end]
        self.relative_heads = [h - start if h else -100 for h in self.heads] # -100 to deal with 'none' headwords

        depths = [self._get_depth_recursive(idx) for idx in range(len(self.relative_heads))]

        return depths

    @lru_cache(maxsize=None)
    def _get_depth_recursive(self, index):
        """Recursively get the depths of every index using a cache and recursion

        Args:
            index (int): Index of the word for which to calculate the relative depth

        Returns:
            int: Relative depth of the word at the index
        """
        # if the head for the current index is outside the scope, this index is a relative root
        if self.relative_heads[index] >= len(self.relative_heads) or self.relative_heads[index] < 0:
            return 0
        return self._get_depth_recursive(self.relative_heads[index]) + 1
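
# Example: with heads [1, None, 1, 2] over the whole list (start=0, end=4),
# word 1 is the root (None head -> depth 0), words 0 and 2 attach to it
# (depth 1), and word 3 attaches to word 2 (depth 2):
#   DynamicDepth().get_parse_depths([1, None, 1, 2], 0, 4) == [1, 0, 1, 2]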

def find_cconj_head(heads, upos, start, end):
    """
    Finds how far each word is from the head of a span, then uses the closest CCONJ to the head as the new head

    If no CCONJ is present, returns None
    """
    # use head information to extract parse depth
    dynamic_depth = DynamicDepth()
    depth = dynamic_depth.get_parse_depths(heads, start, end)
    depth_limit = 2

    # return the first 'CCONJ' token within the depth limit, if one exists
    # unlike the original paper, we expect the parses to use UPOS, hence CCONJ instead of CC
    cc_indexes = [i for i in range(end - start) if upos[i+start] == 'CCONJ' and depth[i] < depth_limit]
    if cc_indexes:
        return cc_indexes[0] + start
    return None
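
# Example (assuming a parse where the conjunction attaches to the first
# conjunct): for the span "cats and dogs" at sentence positions 2..4 of
# "We saw cats and dogs", with "cats" headed by the verb outside the span
# and both "and" and "dogs" headed by "cats":
#   heads = [1, None, 1, 2, 2]
#   upos  = ['PRON', 'VERB', 'NOUN', 'CCONJ', 'NOUN']
#   find_cconj_head(heads, upos, 2, 5) == 3   # the index of "and"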

def process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=True, lang=None):
    """
    doc_id: a string naming the document
    part_id: if the document has a particular subpart (can be blank)
    sentences: a list of list of string representing the raw text

    coref_spans: a list of lists
    one list per sentence
    each sentence has a list of spans, where each span is (span_index, span_start, span_end)
    the indices are relative to 0 for that particular sentence, and if the span is exactly 1 word long, span_start == span_end

    sentence_speakers: a list of list of string representing who said each word.  can all be blank if there are no known speakers
    """
    sentence_lens = [len(x) for x in sentences]
    if sentence_speakers is None:
        # no speaker information: use a blank speaker string for each sentence
        sentence_speakers = [" " for _ in sentences]
    # speakers may come per-word (a list per sentence) or per-sentence (a single string)
    if all(isinstance(x, list) for x in sentence_speakers):
        speaker = [y for x in sentence_speakers for y in x]
    else:
        speaker = [y for x, sent_len in zip(sentence_speakers, sentence_lens) for y in [x] * sent_len]
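    # e.g. per-sentence speakers ["A", "B"] over sentences of lengths 2 and 3
    # expand to the per-word list ["A", "A", "B", "B", "B"]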

    cased_words = [y for x in sentences for y in x]
    sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]

    # use the trees to get the xpos tags
    # alternatively, could translate the pos_tags field,
    # but those have numbers, which is annoying
    #tree_text = "\n".join(x['parse_tree'] for x in paragraph)
    #trees = tree_reader.read_trees(tree_text)
    #pos = [x.label for tree in trees for x in tree.yield_preterminals()]
    # actually, the downstream code doesn't use pos at all.  maybe we can skip?

    doc = pipe(sentences)
    word_total = 0
    heads = []
    # TODO: does SD vs UD matter?
    deprel = []
    for sentence in doc.sentences:
        for word in sentence.words:
            deprel.append(word.deprel)
            if word.head == 0:
                # stanza marks roots with head == 0; the output format uses the string "null"
                heads.append("null")
            else:
                # convert the 1-based sentence-level head to a 0-based document-level index
                heads.append(word.head - 1 + word_total)
        word_total += len(sentence.words)
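    # e.g. (hypothetical) two sentences with stanza heads [2, 0, 2] and [0, 1]
    # (1-based, 0 = root) become the document-level list [1, "null", 1, "null", 3]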

    span_clusters = defaultdict(list)
    word_clusters = defaultdict(list)
    head2span = []
    word_total = 0
    for sent_idx, (parsed_sentence, ontonotes_words) in enumerate(zip(doc.sentences, sentences)):
        sentence_upos = [x.upos for x in parsed_sentence.words]
        sentence_heads = [x.head - 1 if x.head > 0 else None for x in parsed_sentence.words]
        for span in coref_spans[sent_idx]:
            # input is expected to be start word, end word + 1
            # counting from 0
            # whereas the OntoNotes coref_span is [start_word, end_word] inclusive
            span_start = span[1] + word_total
            span_end = span[2] + word_total + 1
            candidate_head = find_cconj_head(sentence_heads, sentence_upos, span[1], span[2]+1) if use_cconj_heads else None
            if candidate_head is None:
                for candidate_head in range(span[1], span[2] + 1):
                    # stanza uses 0 to mark the head, whereas OntoNotes is counting
                    # words from 0, so we have to subtract 1 from the stanza heads
                    #print(span, candidate_head, parsed_sentence.words[candidate_head].head - 1)
                    # treat the head of the phrase as the first word that has a head outside the phrase
                    if (parsed_sentence.words[candidate_head].head - 1 < span[1] or
                        parsed_sentence.words[candidate_head].head - 1 > span[2]):
                        break
                else:
                    # if none have a head outside the phrase (circular??)
                    # then just take the first word
                    candidate_head = span[1]
            #print("----> %d" % candidate_head)
            candidate_head += word_total
            span_clusters[span[0]].append((span_start, span_end))
            word_clusters[span[0]].append(candidate_head)
            head2span.append((candidate_head, span_start, span_end))
        word_total += len(ontonotes_words)
    span_clusters = sorted([sorted(values) for _, values in span_clusters.items()])
    word_clusters = sorted([sorted(values) for _, values in word_clusters.items()])
    head2span = sorted(head2span)

    processed = {
        "document_id": doc_id,
        "part_id": part_id,
        "cased_words": cased_words,
        "sent_id": sent_id,
        "speaker": speaker,
        #"pos": pos,
        "deprel": deprel,
        "head": heads,
        "span_clusters": span_clusters,
        "word_clusters": word_clusters,
        "head2span": head2span,
    }
    if lang is not None:
        processed["lang"] = lang
    return processed
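

# Usage sketch (illustrative only; the pipeline settings and sample inputs
# below are assumptions, not part of the conversion code itself):
if __name__ == "__main__":
    import stanza

    pipe = stanza.Pipeline("en", processors="tokenize,pos,lemma,depparse",
                           tokenize_pretokenized=True)
    sentences = [["John", "saw", "his", "dog", "."]]
    # spans are (cluster_index, start_word, end_word), inclusive:
    # cluster 0 holds "John" (word 0) and "his" (word 2)
    coref_spans = [[(0, 0, 0), (0, 2, 2)]]
    speakers = [["narrator"] * len(sentences[0])]
    processed = process_document(pipe, "sample_doc", 0, sentences,
                                 coref_spans, speakers, lang="en")
    print(processed["head2span"])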