|
|
from collections import defaultdict |
|
|
from functools import lru_cache |
|
|
|
|
|
class DynamicDepth():
    """
    Implements a cache + dynamic programming to find the relative depth of every word in a subphrase given the head word for every word.

    The cache lives on the instance and is rebuilt on every call to
    get_parse_depths, so one instance can safely be reused for any
    number of subphrases.
    """

    # Relative head assigned to a word which has no head at all.
    # Any negative value is read as "head is outside the subphrase".
    NO_HEAD = -100

    def __init__(self):
        self.heads = []
        self.relative_heads = []
        self._depth_cache = {}

    def get_parse_depths(self, heads, start, end):
        """Return the relative depth for every word

        A word whose head is None, or whose head lies outside of
        heads[start:end], has depth 0.  Every other word is one level
        deeper than its head.

        Args:
            heads (list): List where each entry is the index of that entry's head word in the dependency parse (None for a word with no head)
            start (int): starting index of the heads for the subphrase
            end (int): ending index of the heads for the subphrase (exclusive, as it is used as a slice bound)

        Returns:
            list: Relative depth in the dependency parse for every word
        """
        self.heads = heads[start:end]
        # Test against None explicitly: a head index of 0 is a real word
        # (the first one), and a plain truthiness test would drop it.
        self.relative_heads = [h - start if h is not None else self.NO_HEAD for h in self.heads]
        # Depths only make sense for the current subphrase, so never
        # carry cached values over from a previous call.
        self._depth_cache = {}

        return [self._get_depth_recursive(index) for index in range(len(self.relative_heads))]

    def _get_depth_recursive(self, index):
        """Recursively get the depths of every index using a cache and recursion

        Args:
            index (int): Index of the word for which to calculate the relative depth

        Returns:
            int: Relative depth of the word at the index
        """
        if index in self._depth_cache:
            return self._depth_cache[index]

        relative_head = self.relative_heads[index]
        if relative_head >= len(self.relative_heads) or relative_head < 0:
            # head is missing or outside of the subphrase
            depth = 0
        else:
            depth = self._get_depth_recursive(relative_head) + 1

        self._depth_cache[index] = depth
        return depth
|
|
|
|
|
def find_cconj_head(heads, upos, start, end):
    """
    Look for a CCONJ which can serve as the head of the span [start, end)

    The relative depth of every word of the span is computed with
    DynamicDepth.  The words are then scanned left to right, and the
    first word tagged CCONJ whose relative depth is below 2 is chosen.

    heads: head index of each word of the sentence, None for a word with no head
    upos: upos tag of each word of the sentence
    start: index of the first word of the span
    end: index one past the last word of the span

    Returns the index of the chosen CCONJ, using the same indexing as start and end.

    If no CCONJ is present, returns None
    """
    parse_depths = DynamicDepth().get_parse_depths(heads, start, end)
    max_depth = 2

    shallow_conjunctions = []
    for offset in range(end - start):
        if upos[offset + start] != 'CCONJ':
            continue
        if parse_depths[offset] >= max_depth:
            # too deep in the span to be a reasonable head
            continue
        shallow_conjunctions.append(offset + start)

    return shallow_conjunctions[0] if shallow_conjunctions else None
|
|
|
|
|
def process_document(pipe, doc_id, part_id, sentences, coref_spans, sentence_speakers, use_cconj_heads=True, lang=None):
    """
    Convert one document into a dict of word level coref annotations

    pipe: called as pipe(sentences).  The result needs .sentences, each with .words,
      and each word needs .head, .deprel and .upos
      (presumably a dependency parsing pipeline run on the pretokenized text - confirm with the caller)
    doc_id: a string naming the document
    part_id: if the document has a particular subpart (can be blank)
    sentences: a list of list of string representing the raw text

    coref_spans: a list of lists
      one list per sentence
      each sentence has a list of spans, where each span is (span_index, span_start, span_end)
      the indices are relative to 0 for that particular sentence, and if the span is exactly 1 word long, span_start == span_end

    sentence_speakers: a list of list of string representing who said each word. can all be blank if there are no known speakers
      if the entries are not all lists, each entry is repeated for every word of its sentence
      None means a " " speaker for every word
    use_cconj_heads: try find_cconj_head first when picking the head word of a span
    lang: if not None, stored in the result under "lang"

    Returns a dict with document level word indices:
      "head" has the index of each word's head, or "null" where word.head == 0
      "span_clusters" has (start, end) pairs where end is one past the last word
      "word_clusters" has the head word chosen for each span
      "head2span" has (head, start, end) triples
    """
    token_counts = [len(tokens) for tokens in sentences]

    if sentence_speakers is None:
        sentence_speakers = [" " for _ in sentences]

    speaker = []
    if all(isinstance(entry, list) for entry in sentence_speakers):
        # one speaker per word already: just flatten
        for entry in sentence_speakers:
            speaker.extend(entry)
    else:
        # one speaker per sentence: repeat it for each word
        for entry, count in zip(sentence_speakers, token_counts):
            speaker.extend([entry] * count)

    cased_words = []
    for tokens in sentences:
        cased_words.extend(tokens)

    sent_id = []
    for idx, count in enumerate(token_counts):
        sent_id.extend([idx] * count)

    doc = pipe(sentences)

    # first pass: document level heads and deprels
    heads = []
    deprel = []
    offset = 0
    for parsed in doc.sentences:
        for word in parsed.words:
            deprel.append(word.deprel)
            # word.head is 1 based, with 0 marking a word with no head
            heads.append("null" if word.head == 0 else word.head - 1 + offset)
        offset += len(parsed.words)

    # second pass: pick a head word for each span and group by cluster
    spans_by_cluster = defaultdict(list)
    heads_by_cluster = defaultdict(list)
    head2span = []
    offset = 0
    for sent_idx, (parsed, raw_words) in enumerate(zip(doc.sentences, sentences)):
        words = parsed.words
        sentence_upos = [w.upos for w in words]
        sentence_heads = [w.head - 1 if w.head > 0 else None for w in words]
        for span in coref_spans[sent_idx]:
            first = span[1]
            last = span[2]
            span_start = first + offset
            span_end = last + offset + 1

            head_word = None
            if use_cconj_heads:
                head_word = find_cconj_head(sentence_heads, sentence_upos, first, last + 1)

            if head_word is None:
                # take the first word whose own head is outside of the span
                # if there is no such word, use the first word of the span
                head_word = first
                for position in range(first, last + 1):
                    governor = words[position].head - 1
                    if governor < first or governor > last:
                        head_word = position
                        break

            head_word += offset
            spans_by_cluster[span[0]].append((span_start, span_end))
            heads_by_cluster[span[0]].append(head_word)
            head2span.append((head_word, span_start, span_end))
        offset += len(raw_words)

    span_clusters = sorted(sorted(members) for members in spans_by_cluster.values())
    word_clusters = sorted(sorted(members) for members in heads_by_cluster.values())
    head2span.sort()

    # part_id is always stored, even when it is None
    processed = {
        "document_id": doc_id,
        "part_id": part_id,
        "cased_words": cased_words,
        "sent_id": sent_id,
        "speaker": speaker,
        "deprel": deprel,
        "head": heads,
        "span_clusters": span_clusters,
        "word_clusters": word_clusters,
        "head2span": head2span,
    }
    if lang is not None:
        processed["lang"] = lang
    return processed
|
|
|