diff --git a/stanza/stanza/models/classifiers/__init__.py b/stanza/stanza/models/classifiers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stanza/stanza/models/classifiers/config.py b/stanza/stanza/models/classifiers/config.py new file mode 100644 index 0000000000000000000000000000000000000000..00e64095d8344c41928d4670c0ed6572743999b6 --- /dev/null +++ b/stanza/stanza/models/classifiers/config.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import List, Union + +# TODO: perhaps put the enums in this file +from stanza.models.classifiers.utils import WVType, ExtraVectors, ModelType + +@dataclass +class CNNConfig: # pylint: disable=too-many-instance-attributes, too-few-public-methods + filter_channels: Union[int, tuple] + filter_sizes: tuple + fc_shapes: tuple + dropout: float + num_classes: int + wordvec_type: WVType + extra_wordvec_method: ExtraVectors + extra_wordvec_dim: int + extra_wordvec_max_norm: float + char_lowercase: bool + charlm_projection: int + has_charlm_forward: bool + has_charlm_backward: bool + + use_elmo: bool + elmo_projection: int + + bert_model: str + bert_finetune: bool + bert_hidden_layers: int + force_bert_saved: bool + + use_peft: bool + lora_rank: int + lora_alpha: float + lora_dropout: float + lora_modules_to_save: List + lora_target_modules: List + + bilstm: bool + bilstm_hidden_dim: int + maxpool_width: int + model_type: ModelType + +@dataclass +class ConstituencyConfig: # pylint: disable=too-many-instance-attributes, too-few-public-methods + fc_shapes: tuple + dropout: float + num_classes: int + + constituency_backprop: bool + constituency_batch_norm: bool + constituency_node_attn: bool + constituency_top_layer: bool + constituency_all_words: bool + + model_type: ModelType diff --git a/stanza/stanza/models/classifiers/utils.py b/stanza/stanza/models/classifiers/utils.py new file mode 100644 index 
0000000000000000000000000000000000000000..ddc50486c373f8b066ca96b03420a24bbb69cbf3 --- /dev/null +++ b/stanza/stanza/models/classifiers/utils.py @@ -0,0 +1,41 @@ +from enum import Enum + +from torch import nn + +""" +Defines some methods which may occur in multiple model types +""" +# NLP machines: +# word2vec are in +# /u/nlp/data/stanfordnlp/model_production/stanfordnlp/extern_data/word2vec +# google vectors are in +# /scr/nlp/data/wordvectors/en/google/GoogleNews-vectors-negative300.txt + +class WVType(Enum): + WORD2VEC = 1 + GOOGLE = 2 + FASTTEXT = 3 + OTHER = 4 + +class ExtraVectors(Enum): + NONE = 1 + CONCAT = 2 + SUM = 3 + +class ModelType(Enum): + CNN = 1 + CONSTITUENCY = 2 + +def build_output_layers(fc_input_size, fc_shapes, num_classes): + """ + Build a sequence of fully connected layers to go from the final conv layer to num_classes + + Returns an nn.ModuleList + """ + fc_layers = [] + previous_layer_size = fc_input_size + for shape in fc_shapes: + fc_layers.append(nn.Linear(previous_layer_size, shape)) + previous_layer_size = shape + fc_layers.append(nn.Linear(previous_layer_size, num_classes)) + return nn.ModuleList(fc_layers) diff --git a/stanza/stanza/models/constituency/dynamic_oracle.py b/stanza/stanza/models/constituency/dynamic_oracle.py new file mode 100644 index 0000000000000000000000000000000000000000..27b90e0bb06bd70e1b6d01efafffc866d19a8de5 --- /dev/null +++ b/stanza/stanza/models/constituency/dynamic_oracle.py @@ -0,0 +1,135 @@ +from collections import namedtuple + +import numpy as np + +from stanza.models.constituency.parse_transitions import Shift, OpenConstituent, CloseConstituent + +RepairEnum = namedtuple("RepairEnum", "name value is_correct") + +def score_candidates(model, state, candidates, candidate_idx): + """ + score candidate fixed sequences by summing up the transition scores of the most important block + + the candidate with the best summed score is chosen, and the candidate sequence is reconstructed from the blocks + """ + 
scores = [] + # could bulkify this if we wanted + for candidate in candidates: + current_state = [state] + for block in candidate[1:candidate_idx]: + for transition in block: + current_state = model.bulk_apply(current_state, [transition]) + score = 0.0 + for transition in candidate[candidate_idx]: + predictions = model.forward(current_state) + t_idx = model.transition_map[transition] + score += predictions[0, t_idx].cpu().item() + current_state = model.bulk_apply(current_state, [transition]) + scores.append(score) + best_idx = np.argmax(scores) + best_candidate = [x for block in candidates[best_idx] for x in block] + return scores, best_idx, best_candidate + +def advance_past_constituents(gold_sequence, cur_index): + """ + Advance cur_index through gold_sequence until we have seen 1 more Close than Open + + The index returned is the index of the Close which occurred after all the stuff + """ + count = 0 + while cur_index < len(gold_sequence): + if isinstance(gold_sequence[cur_index], OpenConstituent): + count = count + 1 + elif isinstance(gold_sequence[cur_index], CloseConstituent): + count = count - 1 + if count == -1: return cur_index + cur_index = cur_index + 1 + return None + +def find_previous_open(gold_sequence, cur_index): + """ + Go backwards from cur_index to find the open which opens the previous block of stuff. + + Return None if it can't be found. + """ + count = 0 + cur_index = cur_index - 1 + while cur_index >= 0: + if isinstance(gold_sequence[cur_index], OpenConstituent): + count = count + 1 + if count > 0: + return cur_index + elif isinstance(gold_sequence[cur_index], CloseConstituent): + count = count - 1 + cur_index = cur_index - 1 + return None + +def find_in_order_constituent_end(gold_sequence, cur_index): + """ + Advance cur_index through gold_sequence until the next block has ended + + This is different from advance_past_constituents in that it will + also return when there is a Shift when count == 0. 
That way, we + return the first block of things we know attach to the left + """ + count = 0 + saw_shift = False + while cur_index < len(gold_sequence): + if isinstance(gold_sequence[cur_index], OpenConstituent): + count = count + 1 + elif isinstance(gold_sequence[cur_index], CloseConstituent): + count = count - 1 + if count == -1: return cur_index + elif isinstance(gold_sequence[cur_index], Shift): + if saw_shift and count == 0: + return cur_index + else: + saw_shift = True + cur_index = cur_index + 1 + return None + +class DynamicOracle(): + def __init__(self, root_labels, oracle_level, repair_types, additional_levels, deactivated_levels): + self.root_labels = root_labels + # default oracle_level will be the UNKNOWN repair type (which each oracle should have) + # transitions after that as experimental or ambiguous, not to be used by default + self.oracle_level = oracle_level if oracle_level is not None else repair_types.UNKNOWN.value + self.repair_types = repair_types + self.additional_levels = set() + if additional_levels: + self.additional_levels = set([repair_types[x.upper()] for x in additional_levels.split(",")]) + self.deactivated_levels = set() + if deactivated_levels: + self.deactivated_levels = set([repair_types[x.upper()] for x in deactivated_levels.split(",")]) + + def fix_error(self, pred_transition, model, state): + """ + Return which error has been made, if any, along with an updated transition list + + We assume the transition sequence builds a correct tree, meaning + that there will always be a CloseConstituent sometime after an + OpenConstituent, for example + """ + gold_transition = state.gold_sequence[state.num_transitions] + if gold_transition == pred_transition: + return self.repair_types.CORRECT, None + + for repair_type in self.repair_types: + if repair_type.fn is None: + continue + if self.oracle_level is not None and repair_type.value > self.oracle_level and repair_type not in self.additional_levels and not repair_type.debug: + continue + 
if repair_type in self.deactivated_levels: + continue + repair = repair_type.fn(gold_transition, pred_transition, state.gold_sequence, state.num_transitions, self.root_labels, model, state) + if repair is None: + continue + + if isinstance(repair, tuple) and len(repair) == 2: + return repair + + # TODO: could update all of the returns to be tuples of length 2 + if repair is not None: + return repair_type, repair + + return self.repair_types.UNKNOWN, None diff --git a/stanza/stanza/models/constituency/parse_transitions.py b/stanza/stanza/models/constituency/parse_transitions.py new file mode 100644 index 0000000000000000000000000000000000000000..f5840d301f8c58e3f18e8279fb130e03dbd68ed6 --- /dev/null +++ b/stanza/stanza/models/constituency/parse_transitions.py @@ -0,0 +1,641 @@ +""" +Defines a series of transitions (open a constituent, close a constituent, etc + +Also defines a State which holds the various data needed to build +a parse tree out of tagged words. +""" + +from abc import ABC, abstractmethod +import ast +from collections import defaultdict +from enum import Enum +import functools +import logging + +from stanza.models.constituency.parse_tree import Tree + +logger = logging.getLogger('stanza') + +class TransitionScheme(Enum): + def __new__(cls, value, short_name): + obj = object.__new__(cls) + obj._value_ = value + obj.short_name = short_name + return obj + + + # top down, so the open transition comes before any constituents + # score on vi_vlsp22 with 5 different sizes of bert layers, + # bert tagger, no silver dataset: + # 0.8171 + TOP_DOWN = 1, "top" + # unary transitions are modeled as one entire transition + # version that uses one transform per item, + # score on experiment described above: + # 0.8157 + # score using one combination step for an entire transition: + # 0.8178 + TOP_DOWN_COMPOUND = 2, "topc" + # unary is a separate transition. 
doesn't help + # score on experiment described above: + # 0.8128 + TOP_DOWN_UNARY = 3, "topu" + + # open transition comes after the first constituent it cares about + # score on experiment described above: + # 0.8205 + # note that this is with an oracle, whereas IN_ORDER_COMPOUND does + # not have a dynamic oracle, so there may be room for improvement + IN_ORDER = 4, "in" + + # in order, with unaries after preterminals represented as a single + # transition after the preterminal + # and unaries elsewhere tied to the rest of the constituent + # score: 0.8186 + IN_ORDER_COMPOUND = 5, "inc" + + # in order, with CompoundUnary on both preterminals and internal nodes + # score: 0.8166 + IN_ORDER_UNARY = 6, "inu" + +@functools.total_ordering +class Transition(ABC): + """ + model is passed in as a dependency injection + for example, an LSTM model can update hidden & output vectors when transitioning + """ + @abstractmethod + def update_state(self, state, model): + """ + update the word queue position, possibly remove old pieces from the constituents state, and return the new constituent + + the return value should be a tuple: + updated word_position + updated constituents + new constituent to put on the queue and None + - note that the constituent shouldn't be on the queue yet + that allows putting it on as a batch operation, which + saves a significant amount of time in an LSTM, for example + OR + data used to make a new constituent and the method used + - for example, CloseConstituent can return the children needed + and itself. 
this allows a batch operation to build + the constituent + """ + + def delta_opens(self): + return 0 + + def apply(self, state, model): + """ + return a new State transformed via this transition + + convenience method to call bulk_apply, which is significantly + faster than single operations for an NN based model + """ + update = model.bulk_apply([state], [self]) + return update[0] + + @abstractmethod + def is_legal(self, state, model): + """ + assess whether or not this transition is legal in this state + + at parse time, the parser might choose a transition which cannot be made + """ + + def components(self): + """ + Return a list of transitions which could theoretically make up this transition + + For example, an Open transition with multiple labels would + return a list of Opens with those labels + """ + return [self] + + @abstractmethod + def short_name(self): + """ + A short name to identify this transition + """ + + def short_label(self): + if not hasattr(self, "label"): + return self.short_name() + + if isinstance(self.label, str): + label = self.label + elif len(self.label) == 1: + label = self.label[0] + else: + label = self.label + return "{}({})".format(self.short_name(), label) + + def __lt__(self, other): + # put the Shift at the front of a list, and otherwise sort alphabetically + if self == other: + return False + if isinstance(self, Shift): + return True + if isinstance(other, Shift): + return False + return str(self) < str(other) + + + @staticmethod + def from_repr(desc): + """ + This method is to avoid using eval() or otherwise trying to + deserialize strings in a possibly untrusted manner when + loading from a checkpoint + """ + if desc == 'Shift': + return Shift() + if desc == 'CloseConstituent': + return CloseConstituent() + labels = desc.split("(", maxsplit=1) + if labels[0] not in ('CompoundUnary', 'OpenConstituent', 'Finalize'): + raise ValueError("Unknown Transition %s" % desc) + if len(labels) == 1: + raise ValueError("Unexpected 
Transition repr, %s needs labels" % labels[0]) + if labels[1][-1] != ')': + raise ValueError("Expected Transition repr for %s: %s(labels)" % (labels[0], labels[0])) + trans_type = labels[0] + labels = labels[1][:-1] + labels = ast.literal_eval(labels) + if trans_type == 'CompoundUnary': + return CompoundUnary(*labels) + if trans_type == 'OpenConstituent': + return OpenConstituent(*labels) + if trans_type == 'Finalize': + return Finalize(*labels) + raise ValueError("Unexpected Transition %s" % desc) + +class Shift(Transition): + def update_state(self, state, model): + """ + This will handle all aspects of a shift transition + + - push the top element of the word queue onto constituents + - pop the top element of the word queue + """ + new_constituent = model.transform_word_to_constituent(state) + return state.word_position+1, state.constituents, new_constituent, None + + def is_legal(self, state, model): + """ + Disallow shifting when the word queue is empty or there are no opens to eventually eat this word + """ + if state.empty_word_queue(): + return False + if model.is_top_down: + # top down transition sequences cannot shift if there are currently no + # Open transitions on the stack. 
in such a case, the new constituent + # will never be reduced + if state.num_opens == 0: + return False + if state.num_opens == 1: + # there must be at least one transition, since there is an open + assert state.transitions.parent is not None + if state.transitions.parent.parent is None: + # only one transition + trans = model.get_top_transition(state.transitions) + # must be an Open, since there is one open and one transitions + # note that an S, FRAG, etc could happen if we're using unary + # and ROOT-S is possible in the case of compound Open + # in both cases, Shift is legal + # Note that the corresponding problem of shifting after the ROOT-S + # has been closed to just ROOT is handled in CloseConstituent + if len(trans.label) == 1 and trans.top_label in model.root_labels: + # don't shift a word at the very start of a parse + # we want there to be an extra layer below ROOT + return False + else: + # in-order k==1 (the only other option currently) + # can shift ONCE, but note that there is no way to consume + # two items in a row if there is no Open on the stack. 
+ # As long as there is one or more open transitions, + # everything can be eaten + if state.num_opens == 0: + if not state.empty_constituents: + return False + return True + + def short_name(self): + return "Shift" + + def __repr__(self): + return "Shift" + + def __eq__(self, other): + if self is other: + return True + if isinstance(other, Shift): + return True + return False + + def __hash__(self): + return hash(37) + +class CompoundUnary(Transition): + def __init__(self, *label): + # the FIRST label will be the top of the tree + # so CompoundUnary that results in root will have root as labels[0], for example + self.label = tuple(label) + + def update_state(self, state, model): + """ + Apply potentially multiple unary transitions to the same preterminal + + It reuses the CloseConstituent machinery + """ + # only the top constituent is meaningful here + constituents = state.constituents + children = [constituents.value] + constituents = constituents.pop() + # unlike with CloseConstituent, our label is not on the stack. + # it is just our label + # ... but we do reuse CloseConstituent's update mechanism + return state.word_position, constituents, (self.label, children), CloseConstituent + + def is_legal(self, state, model): + """ + Disallow consecutive CompoundUnary transitions, force final transition to go to ROOT + """ + # can't unary transition nothing + tree = model.get_top_constituent(state.constituents) + if tree is None: + return False + # don't unary transition a dummy, dummy + # and don't stack CompoundUnary transitions + if isinstance(model.get_top_transition(state.transitions), (CompoundUnary, OpenConstituent)): + return False + # if we are doing IN_ORDER_COMPOUND, then we are only using these + # transitions to model changes from a tag node to a sequence of + # unary nodes. 
can only occur at preterminals + if model.transition_scheme() is TransitionScheme.IN_ORDER_COMPOUND: + return tree.is_preterminal() + if model.transition_scheme() is not TransitionScheme.TOP_DOWN_UNARY: + return True + + is_root = self.label[0] in model.root_labels + if not state.empty_word_queue() or not state.has_one_constituent(): + return not is_root + else: + return is_root + + def components(self): + return [CompoundUnary(label) for label in self.label] + + def short_name(self): + return "Unary" + + def __repr__(self): + return "CompoundUnary(%s)" % ",".join(self.label) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, CompoundUnary): + return False + if self.label == other.label: + return True + return False + + def __hash__(self): + return hash(self.label) + +class Dummy(): + """ + Takes a space on the constituent stack to represent where an Open transition occurred + """ + def __init__(self, label): + self.label = label + + def is_preterminal(self): + return False + + def __format__(self, spec): + if spec is None or spec == '' or spec == 'O': + return "(%s ...)" % self.label + if spec == 'T': + return "\Tree [.%s ? 
]" % self.label + raise ValueError("Unhandled spec: %s" % spec) + + def __str__(self): + return "Dummy({})".format(self.label) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, Dummy): + return False + if self.label == other.label: + return True + return False + + def __hash__(self): + return hash(self.label) + +def too_many_unary_nodes(tree, unary_limit): + """ + Return True iff there are UNARY_LIMIT unary nodes in a tree in a row + + helps prevent infinite open/close patterns + otherwise, the model can get stuck in essentially an infinite loop + """ + if tree is None: + return False + for _ in range(unary_limit + 1): + if len(tree.children) != 1: + return False + tree = tree.children[0] + return True + +class OpenConstituent(Transition): + def __init__(self, *label): + self.label = tuple(label) + self.top_label = self.label[0] + + def delta_opens(self): + return 1 + + def update_state(self, state, model): + # open a new constituent which can later be closed + # puts a DUMMY constituent on the stack to mark where the constituents end + return state.word_position, state.constituents, model.dummy_constituent(Dummy(self.label)), None + + def is_legal(self, state, model): + """ + disallow based on the length of the sentence + """ + if state.num_opens > state.sentence_length + 10: + # fudge a bit so we don't miss root nodes etc in very small trees + # also there's one really deep tree in CTB 9.0 + return False + if model.is_top_down: + # If the model is top down, you can't Open if there are + # no words to eventually eat + if state.empty_word_queue(): + return False + # Also, you can only Open a ROOT iff it is at the root position + # The assumption in the unary scheme is there will be no + # root open transitions + if not model.has_unary_transitions(): + # TODO: maybe cache this value if this is an expensive operation + is_root = self.top_label in model.root_labels + if is_root: + return state.empty_transitions() + else: + 
return not state.empty_transitions() + else: + # in-order nodes can Open as long as there is at least one thing + # on the constituency stack + # since closing the in-order involves removing one more + # item before the open, and it can close at any time + # (a close immediately after the open represents a unary) + if state.empty_constituents: + return False + if isinstance(model.get_top_transition(state.transitions), OpenConstituent): + # consecutive Opens don't make sense in the context of in-order + return False + if not model.transition_scheme() is TransitionScheme.IN_ORDER: + # eg, IN_ORDER_UNARY or IN_ORDER_COMPOUND + # if compound unary opens are used + # or the unary transitions are via CompoundUnary + # can always open as long as the word queue isn't empty + # if the word queue is empty, only close is allowed + return not state.empty_word_queue() + # one other restriction - we assume all parse trees + # start with (ROOT (first_real_con ...)) + # therefore ROOT can only occur via Open after everything + # else has been pushed and processed + # there are no further restrictions + is_root = self.top_label in model.root_labels + if is_root: + # can't make a root node if it will be in the middle of the parse + # can't make a root node if there's still words to eat + # note that the second assumption wouldn't work, + # except we are assuming there will never be multiple + # nodes under one root + return state.num_opens == 0 and state.empty_word_queue() + else: + if (state.num_opens > 0 or state.empty_word_queue()) and too_many_unary_nodes(model.get_top_constituent(state.constituents), model.unary_limit()): + # looks like we've been in a loop of lots of unary transitions + # note that we check `num_opens > 0` because otherwise we might wind up stuck + # in a state where the only legal transition is open, such as if the + # constituent stack is otherwise empty, but the open is illegal because + # it causes too many unaries + # in such a case we can forbid the 
corresponding close instead... + # if empty_word_queue, that means it is trying to make infinitiely many + # non-ROOT Open transitions instead of just transitioning ROOT + return False + return True + return True + + def components(self): + return [OpenConstituent(label) for label in self.label] + + def short_name(self): + return "Open" + + def __repr__(self): + return "OpenConstituent({})".format(self.label) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, OpenConstituent): + return False + if self.label == other.label: + return True + return False + + def __hash__(self): + return hash(self.label) + +class Finalize(Transition): + """ + Specifically applies at the end of a parse sequence to add a ROOT + + Seemed like the simplest way to remove ROOT from the + in_order_compound transitions while still using the mechanism of + the transitions to build the parse tree + """ + def __init__(self, *label): + self.label = tuple(label) + + def update_state(self, state, model): + """ + Apply potentially multiple unary transitions to the same preterminal + + Only applies to preterminals + It reuses the CloseConstituent machinery + """ + # only the top constituent is meaningful here + constituents = state.constituents + children = [constituents.value] + constituents = constituents.pop() + # unlike with CloseConstituent, our label is not on the stack. + # it is just our label + label = self.label + + # ... 
but we do reuse CloseConstituent's update + return state.word_position, constituents, (label, children), CloseConstituent + + def is_legal(self, state, model): + """ + Legal if & only if there is one tree, no more words, and no ROOT yet + """ + return state.empty_word_queue() and state.has_one_constituent() and not state.finished(model) + + def short_name(self): + return "Finalize" + + def __repr__(self): + return "Finalize(%s)" % ",".join(self.label) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, Finalize): + return False + return other.label == self.label + + def __hash__(self): + return hash((53, self.label)) + +class CloseConstituent(Transition): + def delta_opens(self): + return -1 + + def update_state(self, state, model): + # pop constituents until we are done + children = [] + constituents = state.constituents + while not isinstance(model.get_top_constituent(constituents), Dummy): + # keep the entire value from the stack - the model may need + # the whole thing to transform the children into a new node + children.append(constituents.value) + constituents = constituents.pop() + # the Dummy has the label on it + label = model.get_top_constituent(constituents).label + # pop past the Dummy as well + constituents = constituents.pop() + if not model.is_top_down: + # the alternative to TOP_DOWN_... 
is IN_ORDER + # in which case we want to pop one more constituent + children.append(constituents.value) + constituents = constituents.pop() + # the children are in the opposite order of what we expect + children.reverse() + + return state.word_position, constituents, (label, children), CloseConstituent + + @staticmethod + def build_constituents(model, data): + """ + builds new constituents out of the incoming data + + data is a list of tuples: (label, children) + the model will batch the build operation + again, the purpose of this batching is to do multiple deep learning operations at once + """ + labels, children_lists = map(list, zip(*data)) + new_constituents = model.build_constituents(labels, children_lists) + return new_constituents + + + def is_legal(self, state, model): + """ + Disallow if there is no Open on the stack yet + + in TOP_DOWN, if the previous transition was the Open (nothing built yet) + in IN_ORDER, previous transition does not matter, except for one small corner case + """ + if state.num_opens <= 0: + return False + if model.is_top_down: + if isinstance(model.get_top_transition(state.transitions), OpenConstituent): + return False + if state.num_opens <= 1 and not state.empty_word_queue(): + # don't close the last open until all words have been used + return False + if model.transition_scheme() == TransitionScheme.TOP_DOWN_COMPOUND: + # when doing TOP_DOWN_COMPOUND, we assume all transitions + # at the ROOT level have an S, SQ, FRAG, etc underneath + # this is checked when the model is first trained + if state.num_opens == 1 and not state.empty_word_queue(): + return False + elif not model.has_unary_transitions(): + # in fact, we have to leave the top level constituent + # under the ROOT open if unary transitions are not possible + if state.num_opens == 2 and not state.empty_word_queue(): + return False + elif model.transition_scheme() is TransitionScheme.IN_ORDER: + if not isinstance(model.get_top_transition(state.transitions), 
OpenConstituent): + # we're not stuck in a loop of unaries + return True + if state.num_opens > 1 or state.empty_word_queue(): + # in either of these cases, the corresponding Open should be eliminated + # if we're stuck in a loop of unaries + return True + node = model.get_top_constituent(state.constituents.pop()) + if too_many_unary_nodes(node, model.unary_limit()): + # at this point, we are in a situation where + # - multiple unaries have happened in a row + # - there is stuff on the word_queue, so a ROOT open isn't legal + # - there's only one constituent on the stack, so the only legal + # option once there are no opens left will be an open + # this means we'll be stuck having to open again if we do close + # this node, so instead we make the Close illegal + return False + else: + # model.transition_scheme() == TransitionScheme.IN_ORDER_COMPOUND or + # model.transition_scheme() == TransitionScheme.IN_ORDER_UNARY: + # in both of these cases, we cannot do open/close + # IN_ORDER_COMPOUND will use compound opens and preterminal unaries + # IN_ORDER_UNARY will use compound unaries + # the only restriction here is that we can't close immediately after an open + # internal unaries are handled by the opens being compound + # preterminal unaries are handled with CompoundUnary + if isinstance(model.get_top_transition(state.transitions), OpenConstituent): + return False + return True + + def short_name(self): + return "Close" + + def __repr__(self): + return "CloseConstituent" + + def __eq__(self, other): + if self is other: + return True + if isinstance(other, CloseConstituent): + return True + return False + + def __hash__(self): + return hash(93) + +def check_transitions(train_transitions, other_transitions, treebank_name): + """ + Check that all the transitions in the other dataset are known in the train set + + Weird nested unaries are warned rather than failed as long as the + components are all known + + There is a tree in VLSP, for example, with three (!) 
nested NP nodes + If this is an unknown compound transition, we won't possibly get it + right when parsing, but at least we don't need to fail + """ + unknown_transitions = set() + for trans in other_transitions: + if trans not in train_transitions: + for component in trans.components(): + if component not in train_transitions: + raise RuntimeError("Found transition {} in the {} set which don't exist in the train set".format(trans, treebank_name)) + unknown_transitions.add(trans) + if len(unknown_transitions) > 0: + logger.warning("Found transitions where the components are all valid transitions, but the complete transition is unknown: %s", sorted(unknown_transitions)) diff --git a/stanza/stanza/models/constituency/parser_training.py b/stanza/stanza/models/constituency/parser_training.py new file mode 100644 index 0000000000000000000000000000000000000000..71cea9aeedcde69dae0e50720a910abbcbf8d27b --- /dev/null +++ b/stanza/stanza/models/constituency/parser_training.py @@ -0,0 +1,771 @@ +from collections import Counter, namedtuple +import copy +import logging +import os +import random +import re + +import torch +from torch import nn + +#from stanza.models.common import pretrain + +from stanza.models.common import utils +from stanza.models.common.foundation_cache import FoundationCache, NoTransformerFoundationCache +from stanza.models.common.large_margin_loss import LargeMarginInSoftmaxLoss +from stanza.models.common.utils import sort_with_indices, unsort +from stanza.models.constituency import parse_transitions +from stanza.models.constituency import transition_sequence +from stanza.models.constituency import tree_reader +from stanza.models.constituency.in_order_compound_oracle import InOrderCompoundOracle +from stanza.models.constituency.in_order_oracle import InOrderOracle +from stanza.models.constituency.lstm_model import LSTMModel +from stanza.models.constituency.parse_transitions import TransitionScheme +from stanza.models.constituency.parse_tree import Tree 
+from stanza.models.constituency.top_down_oracle import TopDownOracle +from stanza.models.constituency.trainer import Trainer +from stanza.models.constituency.utils import retag_trees, build_optimizer, build_scheduler, verify_transitions, get_open_nodes, check_constituents, check_root_labels, remove_duplicate_trees, remove_singleton_trees +from stanza.server.parser_eval import EvaluateParser, ParseResult +from stanza.utils.get_tqdm import get_tqdm + +tqdm = get_tqdm() + +tlogger = logging.getLogger('stanza.constituency.trainer') + +TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals']) + +class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])): + def __add__(self, other): + transitions_correct = self.transitions_correct + other.transitions_correct + transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect + repairs_used = self.repairs_used + other.repairs_used + fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used + epoch_loss = self.epoch_loss + other.epoch_loss + nans = self.nans + other.nans + return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) + +def evaluate(args, model_file, retag_pipeline): + """ + Loads the given model file and tests the eval_file treebank. 
+ + May retag the trees using retag_pipeline + Uses a subprocess to run the Java EvalB code + """ + # we create the Evaluator here because otherwise the transformers + # library constantly complains about forking the process + # note that this won't help in the event of training multiple + # models in the same run, although since that would take hours + # or days, that's not a very common problem + if args['num_generate'] > 0: + kbest = args['num_generate'] + 1 + else: + kbest = None + + with EvaluateParser(kbest=kbest) as evaluator: + foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache() + load_args = { + 'wordvec_pretrain_file': args['wordvec_pretrain_file'], + 'charlm_forward_file': args['charlm_forward_file'], + 'charlm_backward_file': args['charlm_backward_file'], + 'device': args['device'], + } + trainer = Trainer.load(model_file, args=load_args, foundation_cache=foundation_cache) + + if args['log_shapes']: + trainer.log_shapes() + + treebank = tree_reader.read_treebank(args['eval_file']) + tlogger.info("Read %d trees for evaluation", len(treebank)) + + retagged_treebank = treebank + if retag_pipeline is not None: + retag_method = trainer.model.retag_method + retag_xpos = retag_method == 'xpos' + tlogger.info("Retagging trees using the %s tags from the %s package...", retag_method, args['retag_package']) + retagged_treebank = retag_trees(treebank, retag_pipeline, retag_xpos) + tlogger.info("Retagging finished") + + if args['log_norms']: + trainer.log_norms() + f1, kbestF1, _ = run_dev_set(trainer.model, retagged_treebank, treebank, args, evaluator) + tlogger.info("F1 score on %s: %f", args['eval_file'], f1) + if kbestF1 is not None: + tlogger.info("KBest F1 score on %s: %f", args['eval_file'], kbestF1) + +def remove_optimizer(args, model_save_file, model_load_file): + """ + A utility method to remove the optimizer from a save file + + Will make the save file a lot smaller + """ + # TODO: kind of overkill to load in the 
pretrain rather than + # change the load/save to work without it, but probably this + # functionality isn't used that often anyway + load_args = { + 'wordvec_pretrain_file': args['wordvec_pretrain_file'], + 'charlm_forward_file': args['charlm_forward_file'], + 'charlm_backward_file': args['charlm_backward_file'], + 'device': args['device'], + } + trainer = Trainer.load(model_load_file, args=load_args, load_optimizer=False) + trainer.save(model_save_file) + +def add_grad_clipping(trainer, grad_clipping): + """ + Adds a torch.clamp hook on each parameter if grad_clipping is not None + """ + if grad_clipping is not None: + for p in trainer.model.parameters(): + if p.requires_grad: + p.register_hook(lambda grad: torch.clamp(grad, -grad_clipping, grad_clipping)) + +def build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file): + """ + Builds a Trainer (with model) and the train_sequences and transitions for the given trees. + """ + train_constituents = Tree.get_unique_constituent_labels(train_trees) + tlogger.info("Unique constituents in training set: %s", train_constituents) + if args['check_valid_states']: + check_constituents(train_constituents, dev_trees, "dev", fail=args['strict_check_constituents']) + check_constituents(train_constituents, silver_trees, "silver", fail=args['strict_check_constituents']) + constituent_counts = Tree.get_constituent_counts(train_trees) + tlogger.info("Constituent node counts: %s", constituent_counts) + + tags = Tree.get_unique_tags(train_trees) + if None in tags: + raise RuntimeError("Fatal problem: the tagger put None on some of the nodes!") + tlogger.info("Unique tags in training set: %s", tags) + # no need to fail for missing tags between train/dev set + # the model has an unknown tag embedding + for tag in Tree.get_unique_tags(dev_trees): + if tag not in tags: + tlogger.info("Found tag in dev set which does not exist in train set: %s Continuing...", tag) + + unary_limit = 
max(max(t.count_unary_depth() for t in train_trees), + max(t.count_unary_depth() for t in dev_trees)) + 1 + if silver_trees: + unary_limit = max(unary_limit, max(t.count_unary_depth() for t in silver_trees)) + tlogger.info("Unary limit: %d", unary_limit) + train_sequences, train_transitions = transition_sequence.convert_trees_to_sequences(train_trees, "training", args['transition_scheme'], args['reversed']) + dev_sequences, dev_transitions = transition_sequence.convert_trees_to_sequences(dev_trees, "dev", args['transition_scheme'], args['reversed']) + silver_sequences, silver_transitions = transition_sequence.convert_trees_to_sequences(silver_trees, "silver", args['transition_scheme'], args['reversed']) + + tlogger.info("Total unique transitions in train set: %d", len(train_transitions)) + tlogger.info("Unique transitions in training set:\n %s", "\n ".join(map(str, train_transitions))) + expanded_train_transitions = set(train_transitions + [x for trans in train_transitions for x in trans.components()]) + if args['check_valid_states']: + parse_transitions.check_transitions(expanded_train_transitions, dev_transitions, "dev") + # theoretically could just train based on the items in the silver dataset + parse_transitions.check_transitions(expanded_train_transitions, silver_transitions, "silver") + + root_labels = Tree.get_root_labels(train_trees) + check_root_labels(root_labels, dev_trees, "dev") + check_root_labels(root_labels, silver_trees, "silver") + tlogger.info("Root labels in treebank: %s", root_labels) + + verify_transitions(train_trees, train_sequences, args['transition_scheme'], unary_limit, args['reversed'], "train", root_labels) + verify_transitions(dev_trees, dev_sequences, args['transition_scheme'], unary_limit, args['reversed'], "dev", root_labels) + + # we don't check against the words in the dev set as it is + # expected there will be some UNK words + words = Tree.get_unique_words(train_trees) + rare_words = Tree.get_rare_words(train_trees, 
args['rare_word_threshold']) + # rare/unknown silver words will just get UNK if they are not already known + if silver_trees and args['use_silver_words']: + tlogger.info("Getting silver words to add to the delta embedding") + silver_words = Tree.get_common_words(tqdm(silver_trees, postfix='Silver words'), len(words)) + words = sorted(set(words + silver_words)) + + # also, it's not actually an error if there is a pattern of + # compound unary or compound open nodes which doesn't exist in the + # train set. it just means we probably won't ever get that right + open_nodes = get_open_nodes(train_trees, args['transition_scheme']) + tlogger.info("Using the following open nodes:\n %s", "\n ".join(map(str, open_nodes))) + + # at this point we have: + # pretrain + # train_trees, dev_trees + # lists of transitions, internal nodes, and root states the parser needs to be aware of + + trainer = Trainer.build_trainer(args, train_transitions, train_constituents, tags, words, rare_words, root_labels, open_nodes, unary_limit, foundation_cache, model_load_file) + + trainer.log_num_words_known(words) + # grad clipping is not saved with the rest of the model, + # so even in the case of a model we saved, + # we now have to add the grad clipping + add_grad_clipping(trainer, args['grad_clipping']) + + return trainer, train_sequences, silver_sequences, train_transitions + +def train(args, model_load_file, retag_pipeline): + """ + Build a model, train it using the requested train & dev files + """ + utils.log_training_args(args, tlogger) + + # we create the Evaluator here because otherwise the transformers + # library constantly complains about forking the process + # note that this won't help in the event of training multiple + # models in the same run, although since that would take hours + # or days, that's not a very common problem + if args['num_generate'] > 0: + kbest = args['num_generate'] + 1 + else: + kbest = None + + if args['wandb']: + global wandb + import wandb + wandb_name = 
args['wandb_name'] if args['wandb_name'] else "%s_constituency" % args['shorthand'] + wandb.init(name=wandb_name, config=args) + wandb.run.define_metric('dev_score', summary='max') + + with EvaluateParser(kbest=kbest) as evaluator: + utils.ensure_dir(args['save_dir']) + + train_trees = tree_reader.read_treebank(args['train_file']) + tlogger.info("Read %d trees for the training set", len(train_trees)) + if args['train_remove_duplicates']: + train_trees = remove_duplicate_trees(train_trees, "train") + train_trees = remove_singleton_trees(train_trees) + + dev_trees = tree_reader.read_treebank(args['eval_file']) + tlogger.info("Read %d trees for the dev set", len(dev_trees)) + dev_trees = remove_duplicate_trees(dev_trees, "dev") + + silver_trees = [] + if args['silver_file']: + silver_trees = tree_reader.read_treebank(args['silver_file']) + tlogger.info("Read %d trees for the silver training set", len(silver_trees)) + if args['silver_remove_duplicates']: + silver_trees = remove_duplicate_trees(silver_trees, "silver") + + if retag_pipeline is not None: + tlogger.info("Retagging trees using the %s tags from the %s package...", args['retag_method'], args['retag_package']) + train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos']) + dev_trees = retag_trees(dev_trees, retag_pipeline, args['retag_xpos']) + silver_trees = retag_trees(silver_trees, retag_pipeline, args['retag_xpos']) + tlogger.info("Retagging finished") + + foundation_cache = retag_pipeline[0].foundation_cache if retag_pipeline else FoundationCache() + trainer, train_sequences, silver_sequences, train_transitions = build_trainer(args, train_trees, dev_trees, silver_trees, foundation_cache, model_load_file) + + if args['log_shapes']: + trainer.log_shapes() + trainer = iterate_training(args, trainer, train_trees, train_sequences, train_transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, evaluator) + + if args['wandb']: + wandb.finish() + + return trainer + +def 
compose_train_data(trees, sequences): + preterminal_lists = [[Tree(label=preterminal.label, children=Tree(label=preterminal.children[0].label)) + for preterminal in tree.yield_preterminals()] + for tree in trees] + data = [TrainItem(*x) for x in zip(trees, sequences, preterminal_lists)] + return data + +def next_epoch_data(leftover_training_data, train_data, epoch_size): + """ + Return the next epoch_size trees from the training data, starting + with leftover data from the previous epoch if there is any + + The training loop generally operates on a fixed number of trees, + rather than going through all the trees in the training set + exactly once, and keeping the leftover training data via this + function ensures that each tree in the training set is touched + once before beginning to iterate again. + """ + if not train_data: + return [], [] + + epoch_data = leftover_training_data + while len(epoch_data) < epoch_size: + random.shuffle(train_data) + epoch_data.extend(train_data) + leftover_training_data = epoch_data[epoch_size:] + epoch_data = epoch_data[:epoch_size] + + return leftover_training_data, epoch_data + +def update_bert_learning_rate(args, optimizer, epochs_trained): + """ + Update the learning rate for the bert finetuning, if applicable + """ + # would be nice to have a parameter group specific scheduler + # however, there is an issue with the optimizer we had the most success with, madgrad + # when the learning rate is 0 for a group, it still learns by some + # small amount because of the eps parameter + # in fact, that is enough to make the learning for the bert in the + # second half broken + for base_param_group in optimizer.param_groups: + if base_param_group['param_group_name'] == 'base': + break + else: + raise AssertionError("There should always be a base parameter group") + for param_group in optimizer.param_groups: + if param_group['param_group_name'] == 'bert': + # Occasionally a model goes haywire and forgets how to use the transformer + # So 
far we have only seen this happen with Electra on the non-NML version of PTB + # We tried fixing that with an increasing transformer learning rate, but that + # didn't fully resolve the problem + # Switching to starting the finetuning after a few epochs seems to help a lot, though + old_lr = param_group['lr'] + if args['bert_finetune_begin_epoch'] is not None and epochs_trained < args['bert_finetune_begin_epoch']: + param_group['lr'] = 0.0 + elif args['bert_finetune_end_epoch'] is not None and epochs_trained >= args['bert_finetune_end_epoch']: + param_group['lr'] = 0.0 + elif args['multistage'] and epochs_trained < args['epochs'] // 2: + param_group['lr'] = base_param_group['lr'] * args['stage1_bert_learning_rate'] + else: + param_group['lr'] = base_param_group['lr'] * args['bert_learning_rate'] + if param_group['lr'] != old_lr: + tlogger.info("Setting %s finetuning rate from %f to %f", param_group['param_group_name'], old_lr, param_group['lr']) + +def iterate_training(args, trainer, train_trees, train_sequences, transitions, dev_trees, silver_trees, silver_sequences, foundation_cache, evaluator): + """ + Given an initialized model, a processed dataset, and a secondary dev dataset, train the model + + The training is iterated in the following loop: + extract a batch of trees of the same length from the training set + convert those trees into initial parsing states + repeat until trees are done: + batch predict the model's interpretation of the current states + add the errors to the list of things to backprop + advance the parsing state for each of the trees + """ + # Somewhat unusual, but possibly related to the extreme variability in length of trees + # Various experiments generally show about 0.5 F1 loss on various + # datasets when using 'mean' instead of 'sum' for reduction + # (Remember to adjust the weight decay when rerunning that experiment) + if args['loss'] == 'cross': + tlogger.info("Building CrossEntropyLoss(sum)") + process_outputs = lambda x: x + 
model_loss_function = nn.CrossEntropyLoss(reduction='sum') + elif args['loss'] == 'focal': + try: + from focal_loss.focal_loss import FocalLoss + except ImportError: + raise ImportError("focal_loss not installed. Must `pip install focal_loss_torch` to use the --loss=focal feature") + tlogger.info("Building FocalLoss, gamma=%f", args['loss_focal_gamma']) + process_outputs = lambda x: torch.softmax(x, dim=1) + model_loss_function = FocalLoss(reduction='sum', gamma=args['loss_focal_gamma']) + elif args['loss'] == 'large_margin': + tlogger.info("Building LargeMarginInSoftmaxLoss(sum)") + process_outputs = lambda x: x + model_loss_function = LargeMarginInSoftmaxLoss(reduction='sum') + else: + raise ValueError("Unexpected loss term: %s" % args['loss']) + + device = trainer.device + model_loss_function.to(device) + transition_tensors = {x: torch.tensor(y, requires_grad=False, device=device).unsqueeze(0) + for (y, x) in enumerate(trainer.transitions)} + trainer.train() + + train_data = compose_train_data(train_trees, train_sequences) + silver_data = compose_train_data(silver_trees, silver_sequences) + + if not args['epoch_size']: + args['epoch_size'] = len(train_data) + if silver_data and not args['silver_epoch_size']: + args['silver_epoch_size'] = args['epoch_size'] + + if args['multistage']: + multistage_splits = {} + # if we're halfway, only do pattn. 
save lattn for next time + multistage_splits[args['epochs'] // 2] = (args['pattn_num_layers'], False) + if LSTMModel.uses_lattn(args): + multistage_splits[args['epochs'] * 3 // 4] = (args['pattn_num_layers'], True) + + oracle = None + if args['transition_scheme'] is TransitionScheme.IN_ORDER: + oracle = InOrderOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels']) + elif args['transition_scheme'] is TransitionScheme.IN_ORDER_COMPOUND: + oracle = InOrderCompoundOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels']) + elif args['transition_scheme'] is TransitionScheme.TOP_DOWN: + oracle = TopDownOracle(trainer.root_labels, args['oracle_level'], args['additional_oracle_levels'], args['deactivated_oracle_levels']) + + leftover_training_data = [] + leftover_silver_data = [] + if trainer.best_epoch > 0: + tlogger.info("Restarting trainer with a model trained for %d epochs. 
Best epoch %d, f1 %f", trainer.epochs_trained, trainer.best_epoch, trainer.best_f1) + + # if we're training a new model, save the initial state so it can be inspected + if args['save_each_start'] == 0 and trainer.epochs_trained == 0: + trainer.save(args['save_each_name'] % trainer.epochs_trained, save_optimizer=True) + + # trainer.epochs_trained+1 so that if the trainer gets saved after 1 epoch, the epochs_trained is 1 + for trainer.epochs_trained in range(trainer.epochs_trained+1, args['epochs']+1): + trainer.train() + tlogger.info("Starting epoch %d", trainer.epochs_trained) + update_bert_learning_rate(args, trainer.optimizer, trainer.epochs_trained) + + if args['log_norms']: + trainer.log_norms() + leftover_training_data, epoch_data = next_epoch_data(leftover_training_data, train_data, args['epoch_size']) + leftover_silver_data, epoch_silver_data = next_epoch_data(leftover_silver_data, silver_data, args['silver_epoch_size']) + epoch_data = epoch_data + epoch_silver_data + epoch_data.sort(key=lambda x: len(x[1])) + + epoch_stats = train_model_one_epoch(trainer.epochs_trained, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args) + + # print statistics + # by now we've forgotten about the original tags on the trees, + # but it doesn't matter for hill climbing + f1, _, _ = run_dev_set(trainer.model, dev_trees, dev_trees, args, evaluator) + if f1 > trainer.best_f1 or (trainer.best_epoch == 0 and trainer.best_f1 == 0.0): + # best_epoch == 0 to force a save of an initial model + # useful for tests which expect something, even when a + # very simple model didn't learn anything + tlogger.info("New best dev score: %.5f > %.5f", f1, trainer.best_f1) + trainer.best_f1 = f1 + trainer.best_epoch = trainer.epochs_trained + trainer.save(args['save_name'], save_optimizer=False) + if epoch_stats.nans > 0: + tlogger.warning("Had to ignore %d batches with NaN", epoch_stats.nans) + stats_log_lines = [ + "Epoch %d finished" % 
trainer.epochs_trained, + "Transitions correct: %s" % epoch_stats.transitions_correct, + "Transitions incorrect: %s" % epoch_stats.transitions_incorrect, + "Total loss for epoch: %.5f" % epoch_stats.epoch_loss, + "Dev score (%5d): %8f" % (trainer.epochs_trained, f1), + "Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1) + ] + tlogger.info("\n ".join(stats_log_lines)) + + old_lr = trainer.optimizer.param_groups[0]['lr'] + trainer.scheduler.step(f1) + new_lr = trainer.optimizer.param_groups[0]['lr'] + if old_lr != new_lr: + tlogger.info("Updating learning rate from %f to %f", old_lr, new_lr) + + if args['wandb']: + wandb.log({'epoch_loss': epoch_stats.epoch_loss, 'dev_score': f1}, step=trainer.epochs_trained) + if args['wandb_norm_regex']: + watch_regex = re.compile(args['wandb_norm_regex']) + for n, p in trainer.model.named_parameters(): + if watch_regex.search(n): + wandb.log({n: torch.linalg.norm(p)}) + + if args['early_dropout'] > 0 and trainer.epochs_trained >= args['early_dropout']: + if any(x > 0.0 for x in (trainer.model.word_dropout.p, trainer.model.predict_dropout.p, trainer.model.lstm_input_dropout.p)): + tlogger.info("Setting dropout to 0.0 at epoch %d", trainer.epochs_trained) + trainer.model.word_dropout.p = 0 + trainer.model.predict_dropout.p = 0 + trainer.model.lstm_input_dropout.p = 0 + + # recreate the optimizer and alter the model as needed if we hit a new multistage split + if args['multistage'] and trainer.epochs_trained in multistage_splits: + # we may be loading a save model from an earlier epoch if the scores stopped increasing + epochs_trained = trainer.epochs_trained + batches_trained = trainer.batches_trained + + stage_pattn_layers, stage_uses_lattn = multistage_splits[epochs_trained] + + # when loading the model, let the saved model determine whether it has pattn or lattn + temp_args = copy.deepcopy(trainer.model.args) + temp_args.pop('pattn_num_layers', None) + temp_args.pop('lattn_d_proj', None) + # overwriting the old 
trainer & model will hopefully free memory + # load a new bert, even in PEFT mode, mostly so that the bert model + # doesn't collect a whole bunch of PEFTs + # for one thing, two PEFTs would mean 2x the optimizer parameters, + # messing up saving and loading the optimizer without jumping + # through more hoops + # loading the trainer w/o the foundation_cache should create + # the necessary bert_model and bert_tokenizer, and then we + # can reuse those values when building out new LSTMModel + trainer = Trainer.load(args['save_name'], temp_args, load_optimizer=False) + model = trainer.model + tlogger.info("Finished stage at epoch %d. Restarting optimizer", epochs_trained) + tlogger.info("Previous best model was at epoch %d", trainer.epochs_trained) + + temp_args = dict(args) + tlogger.info("Switching to a model with %d pattn layers and %slattn", stage_pattn_layers, "" if stage_uses_lattn else "NO ") + temp_args['pattn_num_layers'] = stage_pattn_layers + if not stage_uses_lattn: + temp_args['lattn_d_proj'] = 0 + pt = foundation_cache.load_pretrain(args['wordvec_pretrain_file']) + forward_charlm = foundation_cache.load_charlm(args['charlm_forward_file']) + backward_charlm = foundation_cache.load_charlm(args['charlm_backward_file']) + new_model = LSTMModel(pt, + forward_charlm, + backward_charlm, + model.bert_model, + model.bert_tokenizer, + model.force_bert_saved, + model.peft_name, + model.transitions, + model.constituents, + model.tags, + model.delta_words, + model.rare_words, + model.root_labels, + model.constituent_opens, + model.unary_limit(), + temp_args) + new_model.to(device) + new_model.copy_with_new_structure(model) + + optimizer = build_optimizer(temp_args, new_model, False) + scheduler = build_scheduler(temp_args, optimizer) + trainer = Trainer(new_model, optimizer, scheduler, epochs_trained, batches_trained, trainer.best_f1, trainer.best_epoch) + add_grad_clipping(trainer, args['grad_clipping']) + + # checkpoint needs to be saved AFTER rebuilding the 
optimizer + # so that assumptions about the optimizer in the checkpoint + # can be made based on the end of the epoch + if args['checkpoint'] and args['checkpoint_save_name']: + trainer.save(args['checkpoint_save_name'], save_optimizer=True) + # same with the "each filename", actually, in case those are + # brought back for more training or even just for testing + if args['save_each_start'] is not None and args['save_each_start'] <= trainer.epochs_trained and trainer.epochs_trained % args['save_each_frequency'] == 0: + trainer.save(args['save_each_name'] % trainer.epochs_trained, save_optimizer=args['save_each_optimizer']) + + return trainer + +def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, model_loss_function, epoch_data, oracle, args): + interval_starts = list(range(0, len(epoch_data), args['train_batch_size'])) + random.shuffle(interval_starts) + + optimizer = trainer.optimizer + + epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0) + + for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)): + batch = epoch_data[interval_start:interval_start+args['train_batch_size']] + batch_stats = train_model_one_batch(epoch, batch_idx, trainer.model, batch, transition_tensors, process_outputs, model_loss_function, oracle, args) + trainer.batches_trained += 1 + + # Early in the training, some trees will be degenerate in a + # way that results in layers going up the tree amplifying the + # weights until they overflow. Generally that problem + # resolves itself in a few iterations, so for now we just + # ignore those batches, but report how often it happens + if batch_stats.nans == 0: + optimizer.step() + optimizer.zero_grad() + epoch_stats = epoch_stats + batch_stats + + + # TODO: refactor the logging? 
+ total_correct = sum(v for _, v in epoch_stats.transitions_correct.items()) + total_incorrect = sum(v for _, v in epoch_stats.transitions_incorrect.items()) + tlogger.info("Transitions correct: %d\n %s", total_correct, str(epoch_stats.transitions_correct)) + tlogger.info("Transitions incorrect: %d\n %s", total_incorrect, str(epoch_stats.transitions_incorrect)) + if len(epoch_stats.repairs_used) > 0: + tlogger.info("Oracle repairs:\n %s", "\n ".join("%s (%s): %d" % (x.name, x.value, y) for x, y in epoch_stats.repairs_used.most_common())) + if epoch_stats.fake_transitions_used > 0: + tlogger.info("Fake transitions used: %d", epoch_stats.fake_transitions_used) + + return epoch_stats + +def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_tensors, process_outputs, model_loss_function, oracle, args): + """ + Train the model for one batch + + The model itself will be updated, and a bunch of stats are returned + It is unclear if this refactoring is useful in any way. Might not be + + ... 
although the indentation does get pretty ridiculous if this is + merged into train_model_one_epoch and then iterate_training + """ + # now we add the state to the trees in the batch + # the state is built as a bulk operation + current_batch = model.initial_state_from_preterminals([x.preterminals for x in training_batch], + [x.tree for x in training_batch], + [x.gold_sequence for x in training_batch]) + + transitions_correct = Counter() + transitions_incorrect = Counter() + repairs_used = Counter() + fake_transitions_used = 0 + + all_errors = [] + all_answers = [] + + # we iterate through the batch in the following sequence: + # predict the logits and the applied transition for each tree in the batch + # collect errors + # - we always train to the desired one-hot vector + # this was a noticeable improvement over training just the + # incorrect transitions + # determine whether the training can continue using the "student" transition + # or if we need to use teacher forcing + # update all states using either the gold or predicted transition + # any trees which are now finished are removed from the training cycle + while len(current_batch) > 0: + outputs, pred_transitions, _ = model.predict(current_batch, is_legal=False) + gold_transitions = [x.gold_sequence[x.num_transitions] for x in current_batch] + trans_tensor = [transition_tensors[gold_transition] for gold_transition in gold_transitions] + all_errors.append(outputs) + all_answers.extend(trans_tensor) + + new_batch = [] + update_transitions = [] + for pred_transition, gold_transition, state in zip(pred_transitions, gold_transitions, current_batch): + # forget teacher forcing vs scheduled sampling + # we're going with idiot forcing + if pred_transition == gold_transition: + transitions_correct[gold_transition.short_name()] += 1 + if state.num_transitions + 1 < len(state.gold_sequence): + if oracle is not None and epoch >= args['oracle_initial_epoch'] and random.random() < args['oracle_forced_errors']: + # TODO: 
could randomly choose from the legal transitions + # perhaps the second best scored transition + fake_transition = random.choice(model.transitions) + if fake_transition.is_legal(state, model): + _, new_sequence = oracle.fix_error(fake_transition, model, state) + if new_sequence is not None: + new_batch.append(state._replace(gold_sequence=new_sequence)) + update_transitions.append(fake_transition) + fake_transitions_used = fake_transitions_used + 1 + continue + new_batch.append(state) + update_transitions.append(gold_transition) + continue + + transitions_incorrect[gold_transition.short_name(), pred_transition.short_name()] += 1 + # if we are on the final operation, there are two choices: + # - the parsing mode is IN_ORDER, and the final transition + # is the close to end the sequence, which has no alternatives + # - the parsing mode is something else, in which case + # we have no oracle anyway + if state.num_transitions + 1 >= len(state.gold_sequence): + continue + + if oracle is None or epoch < args['oracle_initial_epoch'] or not pred_transition.is_legal(state, model): + new_batch.append(state) + update_transitions.append(gold_transition) + continue + + repair_type, new_sequence = oracle.fix_error(pred_transition, model, state) + # we can only reach here on an error + assert not repair_type.is_correct + repairs_used[repair_type] += 1 + if new_sequence is not None and random.random() < args['oracle_frequency']: + new_batch.append(state._replace(gold_sequence=new_sequence)) + update_transitions.append(pred_transition) + else: + new_batch.append(state) + update_transitions.append(gold_transition) + + if len(current_batch) > 0: + # bulk update states - significantly faster + current_batch = model.bulk_apply(new_batch, update_transitions, fail=True) + + errors = torch.cat(all_errors) + answers = torch.cat(all_answers) + + errors = process_outputs(errors) + tree_loss = model_loss_function(errors, answers) + tree_loss.backward() + if args['watch_regex']: + matched = 
False + tlogger.info("Watching %s ... epoch %d batch %d", args['watch_regex'], epoch, batch_idx) + watch_regex = re.compile(args['watch_regex']) + for n, p in trainer.model.named_parameters(): + if watch_regex.search(n): + matched = True + if p.requires_grad and p.grad is not None: + tlogger.info(" %s norm: %f grad: %f", n, torch.linalg.norm(p), torch.linalg.norm(p.grad)) + elif p.requires_grad: + tlogger.info(" %s norm: %f grad required, but is None!", n, torch.linalg.norm(p)) + else: + tlogger.info(" %s norm: %f grad not required", n, torch.linalg.norm(p)) + if not matched: + tlogger.info(" (none found!)") + if torch.any(torch.isnan(tree_loss)): + batch_loss = 0.0 + nans = 1 + else: + batch_loss = tree_loss.item() + nans = 0 + + return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans) + +def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None): + """ + This reparses a treebank and executes the CoreNLP Java EvalB code. + + It only works if CoreNLP 4.3.0 or higher is in the classpath. 
+ """ + tlogger.info("Processing %d trees from %s", len(retagged_trees), args['eval_file']) + model.eval() + + num_generate = args.get('num_generate', 0) + keep_scores = num_generate > 0 + + sorted_trees, original_indices = sort_with_indices(retagged_trees, key=len, reverse=True) + tree_iterator = iter(tqdm(sorted_trees)) + treebank = model.parse_sentences_no_grad(tree_iterator, model.build_batch_from_trees, args['eval_batch_size'], model.predict, keep_scores=keep_scores) + treebank = unsort(treebank, original_indices) + full_results = treebank + + if num_generate > 0: + tlogger.info("Generating %d random analyses", args['num_generate']) + generated_treebanks = [treebank] + for i in tqdm(range(num_generate)): + tree_iterator = iter(tqdm(retagged_trees, leave=False, postfix="tb%03d" % i)) + generated_treebanks.append(model.parse_sentences_no_grad(tree_iterator, model.build_batch_from_trees, args['eval_batch_size'], model.weighted_choice, keep_scores=keep_scores)) + + #best_treebank = [ParseResult(parses[0].gold, [max([p.predictions[0] for p in parses], key=itemgetter(1))], None, None) + # for parses in zip(*generated_treebanks)] + #generated_treebanks = [best_treebank] + generated_treebanks + + # TODO: if the model is dropping trees, this will not work + full_results = [ParseResult(parses[0].gold, [p.predictions[0] for p in parses], None, None) + for parses in zip(*generated_treebanks)] + + if len(full_results) < len(retagged_trees): + tlogger.warning("Only evaluating %d trees instead of %d", len(full_results), len(retagged_trees)) + else: + full_results = [x._replace(gold=gold) for x, gold in zip(full_results, original_trees)] + + if args.get('mode', None) == 'predict' and args['predict_file']: + utils.ensure_dir(args['predict_dir'], verbose=False) + pred_file = os.path.join(args['predict_dir'], args['predict_file'] + ".pred.mrg") + orig_file = os.path.join(args['predict_dir'], args['predict_file'] + ".orig.mrg") + if os.path.exists(pred_file): + 
tlogger.warning("Cowardly refusing to overwrite {}".format(pred_file)) + elif os.path.exists(orig_file): + tlogger.warning("Cowardly refusing to overwrite {}".format(orig_file)) + else: + with open(pred_file, 'w') as fout: + for tree in full_results: + output_tree = tree.predictions[0].tree + if args['predict_output_gold_tags']: + output_tree = output_tree.replace_tags(tree.gold) + fout.write(args['predict_format'].format(output_tree)) + fout.write("\n") + + for i in range(num_generate): + pred_file = os.path.join(args['predict_dir'], args['predict_file'] + ".%03d.pred.mrg" % i) + with open(pred_file, 'w') as fout: + for tree in generated_treebanks[-(i+1)]: + output_tree = tree.predictions[0].tree + if args['predict_output_gold_tags']: + output_tree = output_tree.replace_tags(tree.gold) + fout.write(args['predict_format'].format(output_tree)) + fout.write("\n") + + with open(orig_file, 'w') as fout: + for tree in full_results: + fout.write(args['predict_format'].format(tree.gold)) + fout.write("\n") + + if len(full_results) == 0: + return 0.0, 0.0 + if evaluator is None: + if num_generate > 0: + kbest = max(len(fr.predictions) for fr in full_results) + else: + kbest = None + with EvaluateParser(kbest=kbest) as evaluator: + response = evaluator.process(full_results) + else: + response = evaluator.process(full_results) + + kbestF1 = response.kbestF1 if response.HasField("kbestF1") else None + return response.f1, kbestF1, response.treeF1 diff --git a/stanza/stanza/models/constituency/partitioned_transformer.py b/stanza/stanza/models/constituency/partitioned_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b46b4338793e685f53ab45254573a6bf7dace78b --- /dev/null +++ b/stanza/stanza/models/constituency/partitioned_transformer.py @@ -0,0 +1,308 @@ +""" +Transformer with partitioned content and position features. 
"""
Transformer with partitioned content and position features.

See section 3 of https://arxiv.org/pdf/1805.01052.pdf
"""

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureDropoutFunction(torch.autograd.function.InplaceFunction):
    """
    Autograd function behind FeatureDropout: one Bernoulli mask is sampled per
    (batch element, feature channel) and shared across the time dimension.
    """

    @staticmethod
    def forward(ctx, input, p=0.5, train=False, inplace=False):
        if p < 0 or p > 1:
            raise ValueError(
                "dropout probability has to be between 0 and 1, but got {}".format(p)
            )

        ctx.p = p
        ctx.train = train
        ctx.inplace = inplace

        if ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.clone()

        if ctx.p > 0 and ctx.train:
            # mask shape (batch, features): the same channels are dropped at
            # every timestep of a given batch element
            ctx.noise = torch.empty(
                (input.size(0), input.size(-1)),
                dtype=input.dtype,
                layout=input.layout,
                device=input.device,
            )
            if ctx.p == 1:
                ctx.noise.fill_(0)
            else:
                # inverted dropout: rescale so the expectation is unchanged
                ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p)
            ctx.noise = ctx.noise[:, None, :]
            output.mul_(ctx.noise)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.p > 0 and ctx.train:
            return grad_output.mul(ctx.noise), None, None, None
        else:
            return grad_output, None, None, None


class FeatureDropout(nn.Dropout):
    """
    Feature-level dropout: takes an input of size len x num_features and drops
    each feature with probability p.  A feature is dropped across the full
    portion of the input that corresponds to a single batch element.
    """

    def forward(self, x):
        if isinstance(x, tuple):
            x_c, x_p = x
            x_c = FeatureDropoutFunction.apply(x_c, self.p, self.training, self.inplace)
            x_p = FeatureDropoutFunction.apply(x_p, self.p, self.training, self.inplace)
            return x_c, x_p
        else:
            return FeatureDropoutFunction.apply(x, self.p, self.training, self.inplace)


# TODO: this module apparently is not treated the same the built-in
# nonlinearity modules, as multiple uses of the same relu on different
# tensors winds up mixing the gradients See if there is a way to
# resolve that other than creating a new nonlinearity for each layer
class PartitionedReLU(nn.ReLU):
    def forward(self, x):
        if isinstance(x, tuple):
            x_c, x_p = x
        else:
            x_c, x_p = torch.chunk(x, 2, dim=-1)
        return super().forward(x_c), super().forward(x_p)


class PartitionedLinear(nn.Module):
    """Applies separate Linear layers to the content / position halves of x."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear_c = nn.Linear(in_features // 2, out_features // 2, bias)
        self.linear_p = nn.Linear(in_features // 2, out_features // 2, bias)

    def forward(self, x):
        if isinstance(x, tuple):
            x_c, x_p = x
        else:
            x_c, x_p = torch.chunk(x, 2, dim=-1)

        out_c = self.linear_c(x_c)
        out_p = self.linear_p(x_p)
        return out_c, out_p


class PartitionedMultiHeadAttention(nn.Module):
    """
    Multi-head attention with separate projections for the content and the
    position halves; attention scores are computed over the joined vectors.
    """

    def __init__(
        self, d_model, n_head, d_qkv, attention_dropout=0.1, initializer_range=0.02
    ):
        super().__init__()

        self.w_qkv_c = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2))
        self.w_qkv_p = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2))
        self.w_o_c = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2))
        self.w_o_p = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2))

        bound = math.sqrt(3.0) * initializer_range
        for param in [self.w_qkv_c, self.w_qkv_p, self.w_o_c, self.w_o_p]:
            nn.init.uniform_(param, -bound, bound)
        self.scaling_factor = 1 / d_qkv ** 0.5

        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, x, mask=None):
        if isinstance(x, tuple):
            x_c, x_p = x
        else:
            x_c, x_p = torch.chunk(x, 2, dim=-1)
        qkv_c = torch.einsum("btf,hfca->bhtca", x_c, self.w_qkv_c)
        qkv_p = torch.einsum("btf,hfca->bhtca", x_p, self.w_qkv_p)
        q_c, k_c, v_c = [c.squeeze(dim=3) for c in torch.chunk(qkv_c, 3, dim=3)]
        q_p, k_p, v_p = [c.squeeze(dim=3) for c in torch.chunk(qkv_p, 3, dim=3)]
        q = torch.cat([q_c, q_p], dim=-1) * self.scaling_factor
        k = torch.cat([k_c, k_p], dim=-1)
        v = torch.cat([v_c, v_p], dim=-1)
        dots = torch.einsum("bhqa,bhka->bhqk", q, k)
        if mask is not None:
            # BUG FIX: was dots.data.masked_fill_(...); mutating through .data
            # hides the masking from autograd.  Use the out-of-place
            # masked_fill so gradient tracking stays correct.
            dots = dots.masked_fill(~mask[:, None, None, :], -float("inf"))
        probs = F.softmax(dots, dim=-1)
        probs = self.dropout(probs)
        o = torch.einsum("bhqk,bhka->bhqa", probs, v)
        o_c, o_p = torch.chunk(o, 2, dim=-1)
        out_c = torch.einsum("bhta,haf->btf", o_c, self.w_o_c)
        out_p = torch.einsum("bhta,haf->btf", o_p, self.w_o_p)
        return out_c, out_p


class PartitionedTransformerEncoderLayer(nn.Module):
    """One pre-norm-free transformer layer built from partitioned pieces."""

    def __init__(self,
                 d_model,
                 n_head,
                 d_qkv,
                 d_ff,
                 ff_dropout,
                 residual_dropout,
                 attention_dropout,
                 activation=PartitionedReLU(),
                 ):
        super().__init__()
        self.self_attn = PartitionedMultiHeadAttention(
            d_model, n_head, d_qkv, attention_dropout=attention_dropout
        )
        self.linear1 = PartitionedLinear(d_model, d_ff)
        self.ff_dropout = FeatureDropout(ff_dropout)
        self.linear2 = PartitionedLinear(d_ff, d_model)

        self.norm_attn = nn.LayerNorm(d_model)
        self.norm_ff = nn.LayerNorm(d_model)
        self.residual_dropout_attn = FeatureDropout(residual_dropout)
        self.residual_dropout_ff = FeatureDropout(residual_dropout)

        self.activation = activation

    def forward(self, x, mask=None):
        # attention block with residual connection and post-norm
        residual = self.self_attn(x, mask=mask)
        residual = torch.cat(residual, dim=-1)
        residual = self.residual_dropout_attn(residual)
        x = self.norm_attn(x + residual)
        # feed-forward block with residual connection and post-norm
        residual = self.linear2(self.ff_dropout(self.activation(self.linear1(x))))
        residual = torch.cat(residual, dim=-1)
        residual = self.residual_dropout_ff(residual)
        x = self.norm_ff(x + residual)
        return x


class PartitionedTransformerEncoder(nn.Module):
    def __init__(self,
                 n_layers,
                 d_model,
                 n_head,
                 d_qkv,
                 d_ff,
                 ff_dropout,
                 residual_dropout,
                 attention_dropout,
                 activation=PartitionedReLU,
                 ):
        super().__init__()
        # a fresh activation per layer; see the TODO on PartitionedReLU
        self.layers = nn.ModuleList([PartitionedTransformerEncoderLayer(d_model=d_model,
                                                                        n_head=n_head,
                                                                        d_qkv=d_qkv,
                                                                        d_ff=d_ff,
                                                                        ff_dropout=ff_dropout,
                                                                        residual_dropout=residual_dropout,
                                                                        attention_dropout=attention_dropout,
                                                                        activation=activation())
                                     for i in range(n_layers)])

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask=mask)
        return x


class ConcatPositionalEncoding(nn.Module):
    """
    Learns a position embedding
    """
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.timing_table = nn.Parameter(torch.FloatTensor(max_len, d_model))
        nn.init.normal_(self.timing_table)

    def forward(self, x):
        timing = self.timing_table[:x.shape[1], :]
        timing = timing.expand(x.shape[0], -1, -1)
        out = torch.cat([x, timing], dim=-1)
        return out


class PartitionedTransformerModule(nn.Module):
    """
    Projects pretrained embeddings down, concatenates timing features, and
    runs the result through a partitioned transformer encoder.
    """

    def __init__(self,
                 n_layers,
                 d_model,
                 n_head,
                 d_qkv,
                 d_ff,
                 ff_dropout,
                 residual_dropout,
                 attention_dropout,
                 word_input_size,
                 bias,
                 morpho_emb_dropout,
                 timing,
                 encoder_max_len,
                 activation=PartitionedReLU()
                 ):
        super().__init__()
        self.project_pretrained = nn.Linear(
            word_input_size, d_model // 2, bias=bias
        )

        self.pattention_morpho_emb_dropout = FeatureDropout(morpho_emb_dropout)
        if timing == 'sin':
            # imported lazily so this module can be loaded and tested without
            # pulling in the rest of the constituency package
            from stanza.models.constituency.positional_encoding import ConcatSinusoidalEncoding
            self.add_timing = ConcatSinusoidalEncoding(d_model=d_model // 2, max_len=encoder_max_len)
        elif timing == 'learned':
            self.add_timing = ConcatPositionalEncoding(d_model=d_model // 2, max_len=encoder_max_len)
        else:
            raise ValueError("Unhandled timing type: %s" % timing)
        self.transformer_input_norm = nn.LayerNorm(d_model)
        self.pattn_encoder = PartitionedTransformerEncoder(
            n_layers,
            d_model=d_model,
            n_head=n_head,
            d_qkv=d_qkv,
            d_ff=d_ff,
            ff_dropout=ff_dropout,
            residual_dropout=residual_dropout,
            attention_dropout=attention_dropout,
        )

    def forward(self, attention_mask, bert_embeddings):
        # Prepares attention mask for feeding into the self-attention
        device = bert_embeddings[0].device
        # BUG FIX: was `if attention_mask:`, which raises
        # "Boolean value of Tensor with more than one element is ambiguous"
        # whenever a batched mask tensor is passed in
        if attention_mask is not None:
            valid_token_mask = attention_mask
        else:
            # no mask provided: every position of every sentence is valid
            valids = []
            for sent in bert_embeddings:
                valids.append(torch.ones(len(sent), device=device))

            padded_data = torch.nn.utils.rnn.pad_sequence(
                valids,
                batch_first=True,
                padding_value=-100
            )

            valid_token_mask = padded_data != -100

        valid_token_mask = valid_token_mask.to(device=device)
        padded_embeddings = torch.nn.utils.rnn.pad_sequence(
            bert_embeddings,
            batch_first=True,
            padding_value=0
        )

        # Project the pretrained embedding onto the desired dimension
        extra_content_annotations = self.project_pretrained(padded_embeddings)

        # Add positional information through the table
        encoder_in = self.add_timing(self.pattention_morpho_emb_dropout(extra_content_annotations))
        encoder_in = self.transformer_input_norm(encoder_in)
        # Put the partitioned input through the partitioned attention
        annotations = self.pattn_encoder(encoder_in, valid_token_mask)

        return annotations
"""
Describes AnaphoricityScorer, a torch module that for a matrix of mentions
produces their anaphoricity scores.
"""
import torch

from stanza.models.coref import utils
from stanza.models.coref.config import Config


class AnaphoricityScorer(torch.nn.Module):
    """ Calculates anaphoricity scores by passing the inputs into a FFNN """

    def __init__(self,
                 in_features: int,
                 config: Config):
        super().__init__()
        # with zero hidden layers the FFNN degenerates into a single output
        # projection, so the "hidden" width is simply the input width
        hidden_size = config.hidden_size if config.n_hidden_layers else in_features

        blocks = []
        previous_size = in_features
        for _ in range(config.n_hidden_layers):
            blocks.append(torch.nn.Linear(previous_size, hidden_size))
            blocks.append(torch.nn.LeakyReLU())
            blocks.append(torch.nn.Dropout(config.dropout_rate))
            previous_size = hidden_size
        self.hidden = torch.nn.Sequential(*blocks)
        self.out = torch.nn.Linear(hidden_size, out_features=1)

        # are we going to predict singletons
        self.predict_singletons = config.singletons

        if self.predict_singletons:
            # maps the per-antecedent scores to whether this mention starts a
            # coref chain; only needed when singletons are being predicted
            self.start_map = torch.nn.Linear(config.rough_k, out_features=1, bias=False)

    def forward(self, *,  # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
                top_mentions: torch.Tensor,
                mentions_batch: torch.Tensor,
                pw_batch: torch.Tensor,
                top_rough_scores_batch: torch.Tensor,
                ) -> torch.Tensor:
        """ Builds a pairwise matrix, scores the pairs and returns the scores.

        Args:
            top_mentions (torch.Tensor): [batch_size, n_ants, mention_emb]
            mentions_batch (torch.Tensor): [batch_size, mention_emb]
            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb]
            top_rough_scores_batch (torch.Tensor): [batch_size, n_ants]

        Returns:
            torch.Tensor [batch_size, n_ants + 1]
            anaphoricity scores for the pairs + a dummy column; when
            singletons are predicted, a start-of-chain column is prepended
        """
        # [batch_size, n_ants, pair_emb]
        pair_matrix = self._get_pair_matrix(mentions_batch, pw_batch, top_mentions)

        if not self.predict_singletons:
            scores = self._ffnn(pair_matrix)
            return utils.add_dummy(scores + top_rough_scores_batch, eps=True)

        # [batch_size, n_ants] coref scores and [batch_size, 1] start scores
        scores, start = self._ffnn(pair_matrix)
        scores = utils.add_dummy(scores + top_rough_scores_batch, eps=True)
        return torch.cat([start, scores], dim=1)

    def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculates anaphoricity scores.

        Args:
            x: tensor of shape [batch_size, n_ants, n_features]

        Returns:
            tensor of shape [batch_size, n_ants]; when singletons are
            predicted, also a start-of-chain tensor of shape [batch_size, 1]
        """
        x = self.out(self.hidden(x)).squeeze(2)

        if not self.predict_singletons:
            return x

        # slice the weight because sometimes fewer than rough_k antecedent
        # scores are available
        start = x @ self.start_map.weight[:, :x.shape[1]].T
        return x, start

    @staticmethod
    def _get_pair_matrix(mentions_batch: torch.Tensor,
                         pw_batch: torch.Tensor,
                         top_mentions: torch.Tensor) -> torch.Tensor:
        """
        Builds the matrix used as input for AnaphoricityScorer:
        [anchor ; antecedent ; elementwise product ; pairwise features].

        Args:
            mentions_batch (torch.Tensor): [batch_size, mention_emb],
                the mentions of the current batch
            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],
                pairwise features of the current batch
            top_mentions (torch.Tensor): [batch_size, n_ants, mention_emb],
                antecedent representations for each mention

        Returns:
            torch.Tensor: [batch_size, n_ants, pair_emb]
        """
        n_ants = pw_batch.shape[1]
        emb_size = mentions_batch.shape[1]

        anchors = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)
        similarity = anchors * top_mentions
        return torch.cat((anchors, top_mentions, similarity, pw_batch), dim=2)
+ """ + def __init__(self): + self._lea_precision = 0.0 + self._lea_recall = 0.0 + self._lea_precision_weighting = 0.0 + self._lea_recall_weighting = 0.0 + self._num_preds = 0.0 + + # muc + self._muc_precision = 0.0 + self._muc_recall = 0.0 + + # b3 + self._b3_precision = 0.0 + self._b3_recall = 0.0 + + # ceafe + self._ceafe_precision = 0.0 + self._ceafe_recall = 0.0 + + @staticmethod + def _f1(p,r): + return (p * r) / (p+r + EPSILON) * 2 + + def add_predictions(self, + gold_clusters: List[List[Hashable]], + pred_clusters: List[List[Hashable]]): + """ + Calculates LEA for the document's clusters and stores them to later + output weighted LEA across documents. + + Returns: + LEA score for the document as a tuple of (f1, precision, recall) + """ + + # if len(gold_clusters) == 0: + # breakpoint() + + self._num_preds += 1 + + recall, r_weight = ClusterChecker._lea(gold_clusters, pred_clusters) + precision, p_weight = ClusterChecker._lea(pred_clusters, gold_clusters) + + self._muc_recall += ClusterChecker._muc(gold_clusters, pred_clusters) + self._muc_precision += ClusterChecker._muc(pred_clusters, gold_clusters) + + self._b3_recall += ClusterChecker._b3(gold_clusters, pred_clusters) + self._b3_precision += ClusterChecker._b3(pred_clusters, gold_clusters) + + ceafe_precision, ceafe_recall = ClusterChecker._ceafe(pred_clusters, gold_clusters) + if math.isnan(ceafe_precision) and len(gold_clusters) > 0: + # because our model predicted no clusters + ceafe_precision = 0.0 + + self._ceafe_precision += ceafe_precision + self._ceafe_recall += ceafe_recall + + self._lea_recall += recall + self._lea_recall_weighting += r_weight + self._lea_precision += precision + self._lea_precision_weighting += p_weight + + doc_precision = precision / (p_weight + EPSILON) + doc_recall = recall / (r_weight + EPSILON) + doc_f1 = (doc_precision * doc_recall) \ + / (doc_precision + doc_recall + EPSILON) * 2 + return doc_f1, doc_precision, doc_recall + + @property + def bakeoff(self): + """ Get the 
F1 macroaverage score used by the bakeoff """ + return sum(self.mbc)/3 + + @property + def mbc(self): + """ Get the F1 average score of (muc, b3, ceafe) over docs """ + avg_precisions = [self._muc_precision, self._b3_precision, self._ceafe_precision] + avg_precisions = [i/(self._num_preds + EPSILON) for i in avg_precisions] + + avg_recalls = [self._muc_recall, self._b3_recall, self._ceafe_recall] + avg_recalls = [i/(self._num_preds + EPSILON) for i in avg_recalls] + + avg_f1s = [self._f1(p,r) for p,r in zip(avg_precisions, avg_recalls)] + + return avg_f1s + + @property + def total_lea(self): + """ Returns weighted LEA for all the documents as + (f1, precision, recall) """ + precision = self._lea_precision / (self._lea_precision_weighting + EPSILON) + recall = self._lea_recall / (self._lea_recall_weighting + EPSILON) + f1 = self._f1(precision, recall) + return f1, precision, recall + + @staticmethod + def _lea(key: List[List[Hashable]], + response: List[List[Hashable]]) -> Tuple[float, float]: + """ See aclweb.org/anthology/P16-1060.pdf. """ + response_clusters = [set(cluster) for cluster in response] + response_map = {mention: cluster + for cluster in response_clusters + for mention in cluster} + importances = [] + resolutions = [] + for entity in key: + size = len(entity) + if size == 1: # entities of size 1 are not annotated + continue + importances.append(size) + correct_links = 0 + for i in range(size): + for j in range(i + 1, size): + correct_links += int(entity[i] + in response_map.get(entity[j], {})) + resolutions.append(correct_links / (size * (size - 1) / 2)) + res = sum(imp * res for imp, res in zip(importances, resolutions)) + weight = sum(importances) + return res, weight + + @staticmethod + def _muc(key: List[List[Hashable]], + response: List[List[Hashable]]) -> float: + """ See aclweb.org/anthology/P16-1060.pdf. 
""" + + response_clusters = [set(cluster) for cluster in response] + response_map = {mention: cluster + for cluster in response_clusters + for mention in cluster} + + top = 0 # sum over k of |k_i| - response_partitions(|k_i|) + bottom = 0 # sum over k of |k_i| - 1 + + for entity in key: + S = len(entity) + # we need to figure the number of DIFFERENT clusters + # the response assigns to members of the entity; ideally + # this number is 1 (i.e. they are all assigned the same + # coref). + response_clusters = [response_map.get(i, None) for i in entity] + # and dedplicate + deduped = [] + for i in response_clusters: + if i == None: + deduped.append(i) + elif i not in deduped: + deduped.append(i) + # the "partitions" will then be size of the deduped list + p_k = len(deduped) + top += (S - p_k) + bottom += (S - 1) + + try: + return top/bottom + except ZeroDivisionError: + logger.warning("muc got a zero division error because the model predicted no spans!") + return 0 # +inf technically + + @staticmethod + def _b3(key: List[List[Hashable]], + response: List[List[Hashable]]) -> float: + """ See aclweb.org/anthology/P16-1060.pdf. 
""" + + response_clusters = [set(cluster) for cluster in response] + + top = 0 # sum over key and response of (|k intersect response|^2/|k|) + bottom = 0 # sum over k of |k_i| + + for entity in key: + bottom += len(entity) + entity = set(entity) + + for res_entity in response_clusters: + top += (len(entity.intersection(res_entity))**2)/len(entity) + + try: + return top/bottom + except ZeroDivisionError: + logger.warning("b3 got a zero division error because the model predicted no spans!") + return 0 # +inf technically + + + + @staticmethod + def _phi4(c1, c2): + return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2)) + + @staticmethod + def _ceafe(clusters: List[List[Hashable]], gold_clusters: List[List[Hashable]]): + """ see https://github.com/ufal/corefud-scorer/blob/main/coval/eval/evaluator.py """ + + try: + from scipy.optimize import linear_sum_assignment + except ImportError: + raise ImportError("To perform CEAF scoring, please install scipy via `pip install scipy` for the Kuhn-Munkres linear assignment scheme.") + + clusters = [c for c in clusters] + scores = np.zeros((len(gold_clusters), len(clusters))) + for i in range(len(gold_clusters)): + for j in range(len(clusters)): + scores[i, j] = ClusterChecker._phi4(gold_clusters[i], clusters[j]) + row_ind, col_ind = linear_sum_assignment(-scores) + similarity = scores[row_ind, col_ind].sum() + + # precision, recall + try: + prec = similarity/len(clusters) + except ZeroDivisionError: + logger.warning("ceafe got a zero division error because the model predicted no spans!") + prec = 0 + recc = similarity/len(gold_clusters) + return prec, recc + diff --git a/stanza/stanza/models/coref/conll.py b/stanza/stanza/models/coref/conll.py new file mode 100644 index 0000000000000000000000000000000000000000..abd1762955aa8f59dfcb1c11af2c4061cce4fe58 --- /dev/null +++ b/stanza/stanza/models/coref/conll.py @@ -0,0 +1,105 @@ +""" Contains functions to produce conll-formatted output files with +predicted spans and 
""" Contains functions to produce conll-formatted output files with
predicted spans and their clustering """

from collections import defaultdict
from contextlib import contextmanager
import os
from typing import List, TextIO

from stanza.models.coref.config import Config
from stanza.models.coref.const import Doc, Span


# pylint: disable=too-many-locals
def write_conll(doc: Doc,
                clusters: List[List[Span]],
                heads: List[int],
                f_obj: TextIO):
    """ Writes span/cluster information to f_obj, which is assumed to be a file
    object open for writing """
    # seven underscore columns; the character at index 9 lands in the HEAD
    # column, which conllu readers expect to be a number
    placeholder = list("\t_" * 7)
    placeholder[9] = "0"
    placeholder = "".join(placeholder)

    doc_id = doc["document_id"].replace("-", "_").replace("/", "_").replace(".", "_")
    words = doc["cased_words"]
    part_id = doc["part_id"]
    sents = doc["sent_id"]

    max_word_len = max(len(w) for w in words)

    span_starts = defaultdict(lambda: [])
    span_ends = defaultdict(lambda: [])
    singletons = defaultdict(lambda: [])

    for cluster_id, cluster in enumerate(clusters):
        if len(heads[cluster_id]) != len(cluster):
            # TODO debug this fact and why it occurs
            continue
        for cluster_part, (start, end) in enumerate(cluster):
            if end - start == 1:
                singletons[start].append((cluster_part, cluster_id))
            else:
                span_starts[start].append((cluster_part, cluster_id))
                span_ends[end - 1].append((cluster_part, cluster_id))

    f_obj.write(f"# newdoc id = {doc_id}\n# global.Entity = eid-head\n")

    word_number = 0
    sent_id = 0
    for word_id, word in enumerate(words):
        markers = []
        for part, cluster_marker in span_starts[word_id]:
            start, end = clusters[cluster_marker][part]
            markers.append(f"(e{cluster_marker}-{min(heads[cluster_marker][part], end-start)}")
        for part, cluster_marker in singletons[word_id]:
            start, end = clusters[cluster_marker][part]
            markers.append(f"(e{cluster_marker}-{min(heads[cluster_marker][part], end-start)})")
        for part, cluster_marker in span_ends[word_id]:
            markers.append(f"e{cluster_marker})")

        # we need our clusters to be ordered such that the one that is
        # closest to the first change is listed last in the chains;
        # pure closers rank infinitely high so they come first
        def marker_rank(marker):
            pieces = marker.split("-")
            if len(pieces) > 1:
                return int(pieces[-1].replace(")", "").strip())
            return float("inf")

        markers = sorted(markers, key=marker_rank, reverse=True)
        cluster_info = "".join(markers) if markers else "_"

        if word_id == 0 or sents[word_id] != sents[word_id - 1]:
            f_obj.write(f"# sent_id = {doc_id}-{sent_id}\n")
            word_number = 0
            sent_id += 1

        if cluster_info != "_":
            cluster_info = f"Entity={cluster_info}"

        f_obj.write(f"{word_id}\t{word}{placeholder}\t{cluster_info}\n")

        word_number += 1

    f_obj.write("\n")


@contextmanager
def open_(config: Config, epochs: int, data_split: str):
    """ Opens conll log files for writing in a safe way. """
    base_filename = f"{config.section}_{data_split}_e{epochs}"
    conll_dir = config.conll_log_dir

    os.makedirs(conll_dir, exist_ok=True)

    gold_path = os.path.join(conll_dir, f"{base_filename}.gold.conll")
    pred_path = os.path.join(conll_dir, f"{base_filename}.pred.conll")
    with open(gold_path, mode="w", encoding="utf8") as gold_f, \
         open(pred_path, mode="w", encoding="utf8") as pred_f:  # type: ignore
        yield (gold_f, pred_f)


# ==== (new file) stanza/stanza/models/coref/const.py ====
""" Contains type aliases for coref module """

from dataclasses import dataclass
from typing import Any, Dict, Tuple

import torch


EPSILON = 1e-7
LARGE_VALUE = 1000  # used instead of inf due to bug #16762 in pytorch

Doc = Dict[str, Any]
Span = Tuple[int, int]


@dataclass
class CorefResult:
    coref_scores: torch.Tensor = None                  # [n_words, k + 1]
    coref_y: torch.Tensor = None                       # [n_words, k + 1]
    rough_y: torch.Tensor = None                       # [n_words, n_words]

    word_clusters: List[List[int]] = None
    span_clusters: List[List[Span]] = None

    rough_scores: torch.Tensor = None                  # [n_words, n_words]
    span_scores: torch.Tensor = None                   # [n_heads, n_words, 2]
    span_y: Tuple[torch.Tensor, torch.Tensor] = None   # [n_heads] x2
"""
Coref chain suitable for attaching to a Document after coref processing
"""

# by not using namedtuple, we can use this object as output from the json module
# in the doc class as long as we wrap the encoder to print these out in dict() form
# CorefMention = namedtuple('CorefMention', ['sentence', 'start_word', 'end_word'])
class CorefMention:
    # a single mention: a word span inside one sentence
    def __init__(self, sentence, start_word, end_word):
        self.sentence = sentence
        self.start_word = start_word
        self.end_word = end_word

class CorefChain:
    # a cluster of mentions plus its representative mention
    def __init__(self, index, mentions, representative_text, representative_index):
        self.index = index
        self.mentions = mentions
        self.representative_text = representative_text
        self.representative_index = representative_index

class CorefAttachment:
    # links one word back to a chain, with flags for its role in the mention
    def __init__(self, chain, is_start, is_end, is_representative):
        self.chain = chain
        self.is_start = is_start
        self.is_end = is_end
        self.is_representative = is_representative

    def to_json(self):
        # flags are only emitted when set, keeping the json output compact
        j = {
            "index": self.chain.index,
            "representative_text": self.chain.representative_text
        }
        for flag, key in ((self.is_start, 'is_start'),
                          (self.is_end, 'is_end'),
                          (self.is_representative, 'is_representative')):
            if flag:
                j[key] = True
        return j
""" + + def __init__(self, bce_weight: float): + assert 0 <= bce_weight <= 1 + super().__init__() + self._bce_module = torch.nn.BCEWithLogitsLoss() + self._bce_weight = bce_weight + + def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch + input_: torch.Tensor, + target: torch.Tensor) -> torch.Tensor: + """ Returns a weighted sum of two losses as a torch.Tensor """ + return (self._nlml(input_, target) + + self._bce(input_, target) * self._bce_weight) + + def _bce(self, + input_: torch.Tensor, + target: torch.Tensor) -> torch.Tensor: + """ For numerical stability, clamps the input before passing it to BCE. + """ + return self._bce_module(torch.clamp(input_, min=-50, max=50), target) + + @staticmethod + def _nlml(input_: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + gold = torch.logsumexp(input_ + torch.log(target), dim=1) + input_ = torch.logsumexp(input_, dim=1) + return (input_ - gold).mean() diff --git a/stanza/stanza/models/coref/model.py b/stanza/stanza/models/coref/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e6c668f4b1b980c4e0473d274037d3cce753d1e1 --- /dev/null +++ b/stanza/stanza/models/coref/model.py @@ -0,0 +1,784 @@ +""" see __init__.py """ + +from datetime import datetime +import dataclasses +import json +import logging +import os +import random +import re +from typing import Any, Dict, List, Optional, Set, Tuple + +import numpy as np # type: ignore +try: + import tomllib +except ImportError: + import tomli as tomllib +import torch +import transformers # type: ignore + +from pickle import UnpicklingError +import warnings + +from stanza.utils.get_tqdm import get_tqdm # type: ignore +tqdm = get_tqdm() + +from stanza.models.coref import bert, conll, utils +from stanza.models.coref.anaphoricity_scorer import AnaphoricityScorer +from stanza.models.coref.cluster_checker import ClusterChecker +from stanza.models.coref.config import Config +from stanza.models.coref.const import 
CorefResult, Doc +from stanza.models.coref.loss import CorefLoss +from stanza.models.coref.pairwise_encoder import PairwiseEncoder +from stanza.models.coref.rough_scorer import RoughScorer +from stanza.models.coref.span_predictor import SpanPredictor +from stanza.models.coref.utils import GraphNode +from stanza.models.coref.word_encoder import WordEncoder +from stanza.models.coref.dataset import CorefDataset +from stanza.models.coref.tokenizer_customization import * + +from stanza.models.common.bert_embedding import load_tokenizer +from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache +from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper + +logger = logging.getLogger('stanza') + +class CorefModel: # pylint: disable=too-many-instance-attributes + """Combines all coref modules together to find coreferent spans. + + Attributes: + config (coref.config.Config): the model's configuration, + see config.toml for the details + epochs_trained (int): number of epochs the model has been trained for + trainable (Dict[str, torch.nn.Module]): trainable submodules with their + names used as keys + training (bool): used to toggle train/eval modes + + Submodules (in the order of their usage in the pipeline): + tokenizer (transformers.AutoTokenizer) + bert (transformers.AutoModel) + we (WordEncoder) + rough_scorer (RoughScorer) + pw (PairwiseEncoder) + a_scorer (AnaphoricityScorer) + sp (SpanPredictor) + """ + def __init__(self, + epochs_trained: int = 0, + build_optimizers: bool = True, + config: Optional[dict] = None, + foundation_cache=None): + """ + A newly created model is set to evaluation mode. 
+ + Args: + config_path (str): the path to the toml file with the configuration + section (str): the selected section of the config file + epochs_trained (int): the number of epochs finished + (useful for warm start) + """ + if config is None: + raise ValueError("Cannot create a model without a config") + self.config = config + self.epochs_trained = epochs_trained + self._docs: Dict[str, List[Doc]] = {} + self._build_model(foundation_cache) + + self.optimizers = {} + self.schedulers = {} + + if build_optimizers: + self._build_optimizers() + self._set_training(False) + + # final coreference resolution score + self._coref_criterion = CorefLoss(self.config.bce_loss_weight) + # score simply for the top-k choices out of the rough scorer + self._rough_criterion = CorefLoss(0) + # exact span matches + self._span_criterion = torch.nn.CrossEntropyLoss(reduction="sum") + + @property + def training(self) -> bool: + """ Represents whether the model is in the training mode """ + return self._training + + @training.setter + def training(self, new_value: bool): + if self._training is new_value: + return + self._set_training(new_value) + + # ========================================================== Public methods + + @torch.no_grad() + def evaluate(self, + data_split: str = "dev", + word_level_conll: bool = False, + eval_lang: Optional[str] = None + ) -> Tuple[float, Tuple[float, float, float]]: + """ Evaluates the modes on the data split provided. 
+ + Args: + data_split (str): one of 'dev'/'test'/'train' + word_level_conll (bool): if True, outputs conll files on word-level + eval_lang (str): which language to evaluate + + Returns: + mean loss + span-level LEA: f1, precision, recal + """ + self.training = False + w_checker = ClusterChecker() + s_checker = ClusterChecker() + try: + data_split_data = f"{data_split}_data" + data_path = self.config.__dict__[data_split_data] + docs = self._get_docs(data_path) + except FileNotFoundError as e: + raise FileNotFoundError("Unable to find data split %s at file %s" % (data_split_data, data_path)) from e + running_loss = 0.0 + s_correct = 0 + s_total = 0 + + with conll.open_(self.config, self.epochs_trained, data_split) \ + as (gold_f, pred_f): + pbar = tqdm(docs, unit="docs", ncols=0) + for doc in pbar: + if eval_lang and doc.get("lang", "") != eval_lang: + # skip that document, only used for ablation where we only + # want to test evaluation on one language + continue + + res = self.run(doc) + + if (res.coref_y.argmax(dim=1) == 1).all(): + logger.warning(f"EVAL: skipping document with no corefs...") + continue + + running_loss += self._coref_criterion(res.coref_scores, res.coref_y).item() + + if res.span_y: + pred_starts = res.span_scores[:, :, 0].argmax(dim=1) + pred_ends = res.span_scores[:, :, 1].argmax(dim=1) + s_correct += ((res.span_y[0] == pred_starts) * (res.span_y[1] == pred_ends)).sum().item() + s_total += len(pred_starts) + + + if word_level_conll: + raise NotImplementedError("We now write Conll-U conforming to UDCoref, which means that the span_clusters annotations will have headword info. 
word_level option is meaningless.") + else: + conll.write_conll(doc, doc["span_clusters"], doc["word_clusters"], gold_f) + conll.write_conll(doc, res.span_clusters, res.word_clusters, pred_f) + + w_checker.add_predictions(doc["word_clusters"], res.word_clusters) + w_lea = w_checker.total_lea + + s_checker.add_predictions(doc["span_clusters"], res.span_clusters) + s_lea = s_checker.total_lea + + del res + + pbar.set_description( + f"{data_split}:" + f" | WL: " + f" loss: {running_loss / (pbar.n + 1):<.5f}," + f" f1: {w_lea[0]:.5f}," + f" p: {w_lea[1]:.5f}," + f" r: {w_lea[2]:<.5f}" + f" | SL: " + f" sa: {s_correct / s_total:<.5f}," + f" f1: {s_lea[0]:.5f}," + f" p: {s_lea[1]:.5f}," + f" r: {s_lea[2]:<.5f}" + ) + logger.info(f"CoNLL-2012 3-Score Average : {w_checker.bakeoff:.5f}") + + return (running_loss / len(docs), *s_checker.total_lea, *w_checker.total_lea, *s_checker.mbc, *w_checker.mbc, w_checker.bakeoff, s_checker.bakeoff) + + def load_weights(self, + path: Optional[str] = None, + ignore: Optional[Set[str]] = None, + map_location: Optional[str] = None, + noexception: bool = False) -> None: + """ + Loads pretrained weights of modules saved in a file located at path. + If path is None, the last saved model with current configuration + in save_dir is loaded. + Assumes files are named like {configuration}_(e{epoch}_{time})*.pt. 
+ """ + if path is None: + # pattern = rf"{self.config.save_name}_\(e(\d+)_[^()]*\).*\.pt" + # tries to load the last checkpoint in the same dir + pattern = rf"{self.config.save_name}.*?\.checkpoint\.pt" + files = [] + os.makedirs(self.config.save_dir, exist_ok=True) + for f in os.listdir(self.config.save_dir): + match_obj = re.match(pattern, f) + if match_obj: + files.append(f) + if not files: + if noexception: + logger.debug("No weights have been loaded", flush=True) + return + raise OSError(f"No weights found in {self.config.save_dir}!") + path = sorted(files)[-1] + path = os.path.join(self.config.save_dir, path) + + if map_location is None: + map_location = self.config.device + logger.debug(f"Loading from {path}...") + try: + state_dicts = torch.load(path, map_location=map_location, weights_only=True) + except UnpicklingError: + state_dicts = torch.load(path, map_location=map_location, weights_only=False) + warnings.warn("The saved coref model has an old format using Config instead of the Config mapped to dict to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. 
Please resave the coref model using this version ASAP.") + self.epochs_trained = state_dicts.pop("epochs_trained", 0) + # just ignore a config in the model, since we should already have one + # TODO: some config elements may be fixed parameters of the model, + # such as the dimensions of the head, + # so we would want to use the ones from the config even if the + # user created a weird shaped model + config = state_dicts.pop("config", {}) + self.load_state_dicts(state_dicts, ignore) + + def load_state_dicts(self, + state_dicts: dict, + ignore: Optional[Set[str]] = None): + """ + Process the dictionaries from the save file + + Loads the weights into the tensors of this model + May also have optimizer and/or schedule state + """ + for key, state_dict in state_dicts.items(): + logger.debug("Loading state: %s", key) + if not ignore or key not in ignore: + if key.endswith("_optimizer"): + self.optimizers[key].load_state_dict(state_dict) + elif key.endswith("_scheduler"): + self.schedulers[key].load_state_dict(state_dict) + elif key == "bert_lora": + assert self.config.lora, "Unable to load state dict of LoRA model into model initialized without LoRA!" 
+ self.bert = load_peft_wrapper(self.bert, state_dict, vars(self.config), logger, self.peft_name) + else: + self.trainable[key].load_state_dict(state_dict, strict=False) + logger.debug(f"Loaded {key}") + if self.config.log_norms: + self.log_norms() + + def build_doc(self, doc: dict) -> dict: + filter_func = TOKENIZER_FILTERS.get(self.config.bert_model, + lambda _: True) + token_map = TOKENIZER_MAPS.get(self.config.bert_model, {}) + + word2subword = [] + subwords = [] + word_id = [] + for i, word in enumerate(doc["cased_words"]): + tokenized_word = (token_map[word] + if word in token_map + else self.tokenizer.tokenize(word)) + tokenized_word = list(filter(filter_func, tokenized_word)) + word2subword.append((len(subwords), len(subwords) + len(tokenized_word))) + subwords.extend(tokenized_word) + word_id.extend([i] * len(tokenized_word)) + doc["word2subword"] = word2subword + doc["subwords"] = subwords + doc["word_id"] = word_id + + doc["head2span"] = [] + if "speaker" not in doc: + doc["speaker"] = ["_" for _ in doc["cased_words"]] + doc["word_clusters"] = [] + doc["span_clusters"] = [] + + return doc + + + @staticmethod + def load_model(path: str, + map_location: str = "cpu", + ignore: Optional[Set[str]] = None, + config_update: Optional[dict] = None, + foundation_cache = None): + if not path: + raise FileNotFoundError("coref model got an invalid path |%s|" % path) + if not os.path.exists(path): + raise FileNotFoundError("coref model file %s not found" % path) + try: + state_dicts = torch.load(path, map_location=map_location, weights_only=True) + except UnpicklingError: + state_dicts = torch.load(path, map_location=map_location, weights_only=False) + warnings.warn("The saved coref model has an old format using Config instead of the Config mapped to dict to store weights. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. 
Please resave the coref model using this version ASAP.") + epochs_trained = state_dicts.pop("epochs_trained", 0) + config = state_dicts.pop('config', None) + if config is None: + raise ValueError("Cannot load this format model without config in the dicts") + if isinstance(config, dict): + config = Config(**config) + if config_update: + for key, value in config_update.items(): + setattr(config, key, value) + model = CorefModel(config=config, build_optimizers=False, + epochs_trained=epochs_trained, foundation_cache=foundation_cache) + model.load_state_dicts(state_dicts, ignore) + return model + + + def run(self, # pylint: disable=too-many-locals + doc: Doc, + ) -> CorefResult: + """ + This is a massive method, but it made sense to me to not split it into + several ones to let one see the data flow. + + Args: + doc (Doc): a dictionary with the document data. + + Returns: + CorefResult (see const.py) + """ + # Encode words with bert + # words [n_words, span_emb] + # cluster_ids [n_words] + words, cluster_ids = self.we(doc, self._bertify(doc)) + + # Obtain bilinear scores and leave only top-k antecedents for each word + # top_rough_scores [n_words, n_ants] + # top_indices [n_words, n_ants] + top_rough_scores, top_indices, rough_scores = self.rough_scorer(words) + + # Get pairwise features [n_words, n_ants, n_pw_features] + pw = self.pw(top_indices, doc) + + batch_size = self.config.a_scoring_batch_size + a_scores_lst: List[torch.Tensor] = [] + + for i in range(0, len(words), batch_size): + pw_batch = pw[i:i + batch_size] + words_batch = words[i:i + batch_size] + top_indices_batch = top_indices[i:i + batch_size] + top_rough_scores_batch = top_rough_scores[i:i + batch_size] + + # a_scores_batch [batch_size, n_ants] + a_scores_batch = self.a_scorer( + top_mentions=words[top_indices_batch], mentions_batch=words_batch, + pw_batch=pw_batch, top_rough_scores_batch=top_rough_scores_batch + ) + a_scores_lst.append(a_scores_batch) + + res = CorefResult() + + # coref_scores 
[n_spans, n_ants] + res.coref_scores = torch.cat(a_scores_lst, dim=0) + + res.coref_y = self._get_ground_truth( + cluster_ids, top_indices, (top_rough_scores > float("-inf")), + self.config.clusters_starts_are_singletons, + self.config.singletons) + + res.word_clusters = self._clusterize(doc, res.coref_scores, top_indices, + self.config.singletons) + + res.span_scores, res.span_y = self.sp.get_training_data(doc, words) + + if not self.training: + res.span_clusters = self.sp.predict(doc, words, res.word_clusters) + + return res + + def save_weights(self, save_path=None, save_optimizers=True): + """ Saves trainable models as state dicts. """ + to_save: List[Tuple[str, Any]] = \ + [(key, value) for key, value in self.trainable.items() + if (self.config.bert_finetune and not self.config.lora) or key != "bert"] + if save_optimizers: + to_save.extend(self.optimizers.items()) + to_save.extend(self.schedulers.items()) + + time = datetime.strftime(datetime.now(), "%Y.%m.%d_%H.%M") + if save_path is None: + save_path = os.path.join(self.config.save_dir, + f"{self.config.save_name}" + f"_e{self.epochs_trained}_{time}.pt") + savedict = {name: module.state_dict() for name, module in to_save} + if self.config.lora: + # so that this dependency remains optional + from peft import get_peft_model_state_dict + savedict["bert_lora"] = get_peft_model_state_dict(self.bert, adapter_name="coref") + savedict["epochs_trained"] = self.epochs_trained # type: ignore + # save as a dictionary because the weights_only=True load option + # doesn't allow for arbitrary @dataclass configs + savedict["config"] = dataclasses.asdict(self.config) + save_dir = os.path.split(save_path)[0] + if save_dir: + os.makedirs(save_dir, exist_ok=True) + torch.save(savedict, save_path) + + def log_norms(self): + lines = ["NORMS FOR MODEL PARAMTERS"] + for t_name, trainable in self.trainable.items(): + for name, param in trainable.named_parameters(): + if param.requires_grad: + lines.append(" %s: %s %.6g (%d)" % 
(t_name, name, torch.norm(param).item(), param.numel())) + logger.info("\n".join(lines)) + + + def train(self, log=False): + """ + Trains all the trainable blocks in the model using the config provided. + + log: whether or not to log using wandb + skip_lang: str if we want to skip training this language (used for ablation) + """ + + if log: + import wandb + wandb.watch((self.bert, self.pw, + self.a_scorer, self.we, + self.rough_scorer, self.sp)) + + docs = self._get_docs(self.config.train_data) + docs_ids = list(range(len(docs))) + avg_spans = docs.avg_span + + best_f1 = None + for epoch in range(self.epochs_trained, self.config.train_epochs): + self.training = True + if self.config.log_norms: + self.log_norms() + running_c_loss = 0.0 + running_s_loss = 0.0 + random.shuffle(docs_ids) + pbar = tqdm(docs_ids, unit="docs", ncols=0) + for doc_indx, doc_id in enumerate(pbar): + doc = docs[doc_id] + + # skip very long documents during training time + if len(doc["subwords"]) > 5000: + continue + + for optim in self.optimizers.values(): + optim.zero_grad() + + res = self.run(doc) + + c_loss = self._coref_criterion(res.coref_scores, res.coref_y) + + if res.span_y: + s_loss = (self._span_criterion(res.span_scores[:, :, 0], res.span_y[0]) + + self._span_criterion(res.span_scores[:, :, 1], res.span_y[1])) / avg_spans / 2 + else: + s_loss = torch.zeros_like(c_loss) + + del res + + (c_loss + s_loss).backward() + + running_c_loss += c_loss.item() + running_s_loss += s_loss.item() + + # log every 100 docs + if log and doc_indx % 100 == 0: + wandb.log({'train_c_loss': c_loss.item(), + 'train_s_loss': s_loss.item()}) + + + del c_loss, s_loss + + for optim in self.optimizers.values(): + optim.step() + for scheduler in self.schedulers.values(): + scheduler.step() + + pbar.set_description( + f"Epoch {epoch + 1}:" + f" {doc['document_id']:26}" + f" c_loss: {running_c_loss / (pbar.n + 1):<.5f}" + f" s_loss: {running_s_loss / (pbar.n + 1):<.5f}" + ) + + self.epochs_trained += 1 + scores = 
self.evaluate() + prev_best_f1 = best_f1 + if log: + wandb.log({'dev_score': scores[1]}) + wandb.log({'dev_bakeoff': scores[-1]}) + + if best_f1 is None or scores[1] > best_f1: + + if best_f1 is None: + logger.info("Saving new best model: F1 %.4f", scores[1]) + else: + logger.info("Saving new best model: F1 %.4f > %.4f", scores[1], best_f1) + best_f1 = scores[1] + if self.config.save_name.endswith(".pt"): + save_path = os.path.join(self.config.save_dir, + f"{self.config.save_name}") + else: + save_path = os.path.join(self.config.save_dir, + f"{self.config.save_name}.pt") + self.save_weights(save_path, save_optimizers=False) + if self.config.save_each_checkpoint: + self.save_weights() + else: + if self.config.save_name.endswith(".pt"): + checkpoint_path = os.path.join(self.config.save_dir, + f"{self.config.save_name[:-3]}.checkpoint.pt") + else: + checkpoint_path = os.path.join(self.config.save_dir, + f"{self.config.save_name}.checkpoint.pt") + self.save_weights(checkpoint_path) + if prev_best_f1 is not None and prev_best_f1 != best_f1: + logger.info("Epoch %d finished.\nSentence F1 %.5f p %.5f r %.5f\nBest F1 %.5f\nPrevious best F1 %.5f", self.epochs_trained, scores[1], scores[2], scores[3], best_f1, prev_best_f1) + else: + logger.info("Epoch %d finished.\nSentence F1 %.5f p %.5f r %.5f\nBest F1 %.5f", self.epochs_trained, scores[1], scores[2], scores[3], best_f1) + + # ========================================================= Private methods + + def _bertify(self, doc: Doc) -> torch.Tensor: + all_batches = bert.get_subwords_batches(doc, self.config, self.tokenizer) + + # we index the batches n at a time to prevent oom + result = [] + for i in range(0, all_batches.shape[0], 1024): + subwords_batches = all_batches[i:i+1024] + + special_tokens = np.array([self.tokenizer.cls_token_id, + self.tokenizer.sep_token_id, + self.tokenizer.pad_token_id, + self.tokenizer.eos_token_id]) + subword_mask = ~(np.isin(subwords_batches, special_tokens)) + + subwords_batches_tensor = 
torch.tensor(subwords_batches, + device=self.config.device, + dtype=torch.long) + subword_mask_tensor = torch.tensor(subword_mask, + device=self.config.device) + + # Obtain bert output for selected batches only + attention_mask = (subwords_batches != self.tokenizer.pad_token_id) + if "t5" in self.config.bert_model: + out = self.bert.encoder( + input_ids=subwords_batches_tensor, + attention_mask=torch.tensor( + attention_mask, device=self.config.device)) + else: + out = self.bert( + subwords_batches_tensor, + attention_mask=torch.tensor( + attention_mask, device=self.config.device)) + + out = out['last_hidden_state'] + # [n_subwords, bert_emb] + result.append(out[subword_mask_tensor]) + + # stack returns and return + return torch.cat(result) + + def _build_model(self, foundation_cache): + if hasattr(self.config, 'lora') and self.config.lora: + self.bert, self.tokenizer, peft_name = load_bert_with_peft(self.config.bert_model, "coref", foundation_cache) + # vars() converts a dataclass to a dict, used for being able to index things like args["lora_*"] + self.bert = build_peft_wrapper(self.bert, vars(self.config), logger, adapter_name=peft_name) + self.peft_name = peft_name + else: + if self.config.bert_finetune: + logger.debug("Coref model requested a finetuned transformer; we are not using the foundation model cache to prevent we accidentally leak the finetuning weights elsewhere.") + foundation_cache = NoTransformerFoundationCache(foundation_cache) + self.bert, self.tokenizer = load_bert(self.config.bert_model, foundation_cache) + + base_bert_name = self.config.bert_model.split("/")[-1] + tokenizer_kwargs = self.config.tokenizer_kwargs.get(base_bert_name, {}) + if tokenizer_kwargs: + logger.debug(f"Using tokenizer kwargs: {tokenizer_kwargs}") + # we just downloaded the tokenizer, so for simplicity, we don't make another request to HF + self.tokenizer = load_tokenizer(self.config.bert_model, tokenizer_kwargs, local_files_only=True) + + if self.config.bert_finetune or 
(hasattr(self.config, 'lora') and self.config.lora): + self.bert = self.bert.train() + + self.bert = self.bert.to(self.config.device) + self.pw = PairwiseEncoder(self.config).to(self.config.device) + + bert_emb = self.bert.config.hidden_size + pair_emb = bert_emb * 3 + self.pw.shape + + # pylint: disable=line-too-long + self.a_scorer = AnaphoricityScorer(pair_emb, self.config).to(self.config.device) + self.we = WordEncoder(bert_emb, self.config).to(self.config.device) + self.rough_scorer = RoughScorer(bert_emb, self.config).to(self.config.device) + self.sp = SpanPredictor(bert_emb, self.config.sp_embedding_size).to(self.config.device) + + self.trainable: Dict[str, torch.nn.Module] = { + "bert": self.bert, "we": self.we, + "rough_scorer": self.rough_scorer, + "pw": self.pw, "a_scorer": self.a_scorer, + "sp": self.sp + } + + def _build_optimizers(self): + n_docs = len(self._get_docs(self.config.train_data)) + self.optimizers: Dict[str, torch.optim.Optimizer] = {} + self.schedulers: Dict[str, torch.optim.lr_scheduler.LRScheduler] = {} + + if not getattr(self.config, 'lora', False): + for param in self.bert.parameters(): + param.requires_grad = self.config.bert_finetune + + if self.config.bert_finetune: + logger.debug("Making bert optimizer with LR of %f", self.config.bert_learning_rate) + self.optimizers["bert_optimizer"] = torch.optim.Adam( + self.bert.parameters(), lr=self.config.bert_learning_rate + ) + start_finetuning = int(n_docs * self.config.bert_finetune_begin_epoch) + if start_finetuning > 0: + logger.info("Will begin finetuning transformer at iteration %d", start_finetuning) + zero_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizers["bert_optimizer"], factor=0, total_iters=start_finetuning) + warmup_scheduler = transformers.get_linear_schedule_with_warmup( + self.optimizers["bert_optimizer"], + start_finetuning, n_docs * self.config.train_epochs - start_finetuning) + self.schedulers["bert_scheduler"] = torch.optim.lr_scheduler.SequentialLR( + 
self.optimizers["bert_optimizer"], + schedulers=[zero_scheduler, warmup_scheduler], + milestones=[start_finetuning]) + + # Must ensure the same ordering of parameters between launches + modules = sorted((key, value) for key, value in self.trainable.items() + if key != "bert") + params = [] + for _, module in modules: + for param in module.parameters(): + param.requires_grad = True + params.append(param) + + self.optimizers["general_optimizer"] = torch.optim.Adam( + params, lr=self.config.learning_rate) + self.schedulers["general_scheduler"] = \ + transformers.get_linear_schedule_with_warmup( + self.optimizers["general_optimizer"], + 0, n_docs * self.config.train_epochs + ) + + def _clusterize(self, doc: Doc, scores: torch.Tensor, top_indices: torch.Tensor, + singletons: bool = True): + if singletons: + antecedents = scores[:,1:].argmax(dim=1) - 1 + # set the dummy values to -1, so that they are not coref to themselves + is_start = (scores[:, :2].argmax(dim=1) == 0) + else: + antecedents = scores.argmax(dim=1) - 1 + + not_dummy = antecedents >= 0 + coref_span_heads = torch.arange(0, len(scores), device=not_dummy.device)[not_dummy] + antecedents = top_indices[coref_span_heads, antecedents[not_dummy]] + + nodes = [GraphNode(i) for i in range(len(doc["cased_words"]))] + for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()): + nodes[i].link(nodes[j]) + assert nodes[i] is not nodes[j] + + visited = {} + + clusters = [] + for node in nodes: + if len(node.links) > 0 and not node.visited: + cluster = [] + stack = [node] + while stack: + current_node = stack.pop() + current_node.visited = True + cluster.append(current_node.id) + stack.extend(link for link in current_node.links if not link.visited) + assert len(cluster) > 1 + for i in cluster: + visited[i] = True + clusters.append(sorted(cluster)) + + if singletons: + # go through the is_start nodes; if no clusters contain that node + # i.e. 
visited[i] == False, we add it as a singleton + for indx, i in enumerate(is_start): + if i and not visited.get(indx, False): + clusters.append([indx]) + + return sorted(clusters) + + def _get_docs(self, path: str) -> List[Doc]: + if path not in self._docs: + self._docs[path] = CorefDataset(path, self.config, self.tokenizer) + return self._docs[path] + + @staticmethod + def _get_ground_truth(cluster_ids: torch.Tensor, + top_indices: torch.Tensor, + valid_pair_map: torch.Tensor, + cluster_starts: bool, + singletons:bool = True) -> torch.Tensor: + """ + Args: + cluster_ids: tensor of shape [n_words], containing cluster indices + for each word. Non-gold words have cluster id of zero. + top_indices: tensor of shape [n_words, n_ants], + indices of antecedents of each word + valid_pair_map: boolean tensor of shape [n_words, n_ants], + whether for pair at [i, j] (i-th word and j-th word) + j < i is True + + Returns: + tensor of shape [n_words, n_ants + 1] (dummy added), + containing 1 at position [i, j] if i-th and j-th words corefer. + """ + y = cluster_ids[top_indices] * valid_pair_map # [n_words, n_ants] + y[y == 0] = -1 # -1 for non-gold words + y = utils.add_dummy(y) # [n_words, n_cands + 1] + + if singletons: + if not cluster_starts: + unique, counts = cluster_ids.unique(return_counts=True) + singleton_clusters = unique[(counts == 1) & (unique != 0)] + first_corefs = [(cluster_ids == i).nonzero().flatten()[0] for i in singleton_clusters] + if len(first_corefs) > 0: + first_coref = torch.stack(first_corefs) + else: + first_coref = torch.tensor([]).to(cluster_ids.device).long() + else: + # I apologize for this abuse of everything that's good about PyTorch. + # in essence, this line finds the INDEX of FIRST OCCURENCE of each NON-ZERO value + # from cluster_ids. We need this information because we use it to mark the + # special "is-start-of-ref" marker used to detect singletons. 
+ first_coref = (cluster_ids == + cluster_ids.unique().sort().values[1:].unsqueeze(1) + ).float().topk(k=1, dim=1).indices.squeeze() + y = (y == cluster_ids.unsqueeze(1)) # True if coreferent + # For all rows with no gold antecedents setting dummy to True + y[y.sum(dim=1) == 0, 0] = True + + if singletons: + # add another dummy for first coref + y = utils.add_dummy(y) # [n_words, n_cands + 2] + # for all rows that's a first coref, setting its dummy to True and unset the + # non-coref dummy to false + y[first_coref, 0] = True + y[first_coref, 1] = False + return y.to(torch.float) + + @staticmethod + def _load_config(config_path: str, + section: str) -> Config: + with open(config_path, "rb") as fin: + config = tomllib.load(fin) + default_section = config["DEFAULT"] + current_section = config[section] + unknown_keys = (set(current_section.keys()) + - set(default_section.keys())) + if unknown_keys: + raise ValueError(f"Unexpected config keys: {unknown_keys}") + return Config(section, **{**default_section, **current_section}) + + def _set_training(self, value: bool): + self._training = value + for module in self.trainable.values(): + module.train(self._training) + diff --git a/stanza/stanza/models/depparse/__init__.py b/stanza/stanza/models/depparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stanza/stanza/models/depparse/scorer.py b/stanza/stanza/models/depparse/scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..f7df9efdaafa1c68923f7fc3ee0b960429c44b04 --- /dev/null +++ b/stanza/stanza/models/depparse/scorer.py @@ -0,0 +1,60 @@ +""" +Utils and wrappers for scoring parsers. 
+""" + +from collections import Counter +import logging + +from stanza.models.common.utils import ud_scores + +logger = logging.getLogger('stanza') + +def score_named_dependencies(pred_doc, gold_doc): + if len(pred_doc.sentences) != len(gold_doc.sentences): + logger.warning("Not evaluating individual dependency F1 on accound of document length mismatch") + return + for sent_idx, (x, y) in enumerate(zip(pred_doc.sentences, gold_doc.sentences)): + if len(x.words) != len(y.words): + logger.warning("Not evaluating individual dependency F1 on accound of sentence length mismatch") + return + + tp = Counter() + fp = Counter() + fn = Counter() + for pred_sentence, gold_sentence in zip(pred_doc.sentences, gold_doc.sentences): + for pred_word, gold_word in zip(pred_sentence.words, gold_sentence.words): + if pred_word.head == gold_word.head and pred_word.deprel == gold_word.deprel: + tp[gold_word.deprel] = tp[gold_word.deprel] + 1 + else: + fn[gold_word.deprel] = fn[gold_word.deprel] + 1 + fp[pred_word.deprel] = fp[pred_word.deprel] + 1 + + labels = sorted(set(tp.keys()).union(fp.keys()).union(fn.keys())) + max_len = max(len(x) for x in labels) + log_lines = [] + log_line_fmt = "%" + str(max_len) + "s: p %.4f r %.4f f1 %.4f (%d actual)" + for label in labels: + if tp[label] == 0: + precision = 0 + recall = 0 + f1 = 0 + else: + precision = tp[label] / (tp[label] + fp[label]) + recall = tp[label] / (tp[label] + fn[label]) + f1 = 2 * (precision * recall) / (precision + recall) + log_lines.append(log_line_fmt % (label, precision, recall, f1, tp[label] + fn[label])) + logger.info("F1 scores for each dependency:\n Note that unlabeled attachment errors hurt the labeled attachment scores\n%s" % "\n".join(log_lines)) + +def score(system_conllu_file, gold_conllu_file, verbose=True): + """ Wrapper for UD parser scorer. 
""" + evaluation = ud_scores(gold_conllu_file, system_conllu_file) + el = evaluation['LAS'] + p = el.precision + r = el.recall + f = el.f1 + if verbose: + scores = [evaluation[k].f1 * 100 for k in ['LAS', 'MLAS', 'BLEX']] + logger.info("LAS\tMLAS\tBLEX") + logger.info("{:.2f}\t{:.2f}\t{:.2f}".format(*scores)) + return p, r, f + diff --git a/stanza/stanza/models/depparse/trainer.py b/stanza/stanza/models/depparse/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..46e70befb975e1822e5ccfc081805185f5de38b9 --- /dev/null +++ b/stanza/stanza/models/depparse/trainer.py @@ -0,0 +1,250 @@ +""" +A trainer class to handle training and testing of models. +""" + +import copy +import sys +import logging +import torch +from torch import nn + +try: + import transformers +except ImportError: + pass + +from stanza.models.common.trainer import Trainer as BaseTrainer +from stanza.models.common import utils, loss +from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, NoTransformerFoundationCache +from stanza.models.common.chuliu_edmonds import chuliu_edmonds_one_root +from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper +from stanza.models.depparse.model import Parser +from stanza.models.pos.vocab import MultiVocab + +logger = logging.getLogger('stanza') + +def unpack_batch(batch, device): + """ Unpack a batch from the data loader. """ + inputs = [b.to(device) if b is not None else None for b in batch[:11]] + orig_idx = batch[11] + word_orig_idx = batch[12] + sentlens = batch[13] + wordlens = batch[14] + text = batch[15] + return inputs, orig_idx, word_orig_idx, sentlens, wordlens, text + +class Trainer(BaseTrainer): + """ A trainer for training models. 
""" + def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, + device=None, foundation_cache=None, ignore_model_config=False, reset_history=False): + self.global_step = 0 + self.last_best_step = 0 + self.dev_score_history = [] + + orig_args = copy.deepcopy(args) + # whether the training is in primary or secondary stage + # during FT (loading weights), etc., the training is considered to be in "secondary stage" + # during this time, we (optionally) use a different set of optimizers than that during "primary stage". + # + # Regardless, we use TWO SETS of optimizers; once primary converges, we switch to secondary + + if model_file is not None: + # load everything from file + self.load(model_file, pretrain, args, foundation_cache, device) + + if reset_history: + self.global_step = 0 + self.last_best_step = 0 + self.dev_score_history = [] + else: + # build model from scratch + self.args = args + self.vocab = vocab + + bert_model, bert_tokenizer = load_bert(self.args['bert_model']) + peft_name = None + if self.args['use_peft']: + # fine tune the bert if we're using peft + self.args['bert_finetune'] = True + peft_name = "depparse" + bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name) + + self.model = Parser(args, vocab, emb_matrix=pretrain.emb if pretrain is not None else None, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name) + self.model = self.model.to(device) + self.__init_optim() + + if ignore_model_config: + self.args = orig_args + + if self.args.get('wandb'): + import wandb + # track gradients! 
+ wandb.watch(self.model, log_freq=4, log="all", log_graph=True) + + def __init_optim(self): + # TODO: can get rid of args.get when models are rebuilt + if (self.args.get("second_stage", False) and self.args.get('second_optim')): + self.optimizer = utils.get_split_optimizer(self.args['second_optim'], self.model, + self.args['second_lr'], betas=(0.9, self.args['beta2']), eps=1e-6, + bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0), + is_peft=self.args.get('use_peft', False), + bert_finetune_layers=self.args.get('bert_finetune_layers', None)) + else: + self.optimizer = utils.get_split_optimizer(self.args['optim'], self.model, + self.args['lr'], betas=(0.9, self.args['beta2']), + eps=1e-6, bert_learning_rate=self.args.get('bert_learning_rate', 0.0), + weight_decay=self.args.get('weight_decay', None), + bert_weight_decay=self.args.get('bert_weight_decay', 0.0), + is_peft=self.args.get('use_peft', False), + bert_finetune_layers=self.args.get('bert_finetune_layers', None)) + self.scheduler = {} + if self.args.get("second_stage", False) and self.args.get('second_optim'): + if self.args.get('second_warmup_steps', None): + for name, optimizer in self.optimizer.items(): + name = name + "_scheduler" + warmup_scheduler = transformers.get_constant_schedule_with_warmup(optimizer, self.args['second_warmup_steps']) + self.scheduler[name] = warmup_scheduler + else: + if "bert_optimizer" in self.optimizer: + zero_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer["bert_optimizer"], factor=0, total_iters=self.args['bert_start_finetuning']) + warmup_scheduler = transformers.get_constant_schedule_with_warmup( + self.optimizer["bert_optimizer"], + self.args['bert_warmup_steps']) + self.scheduler["bert_scheduler"] = torch.optim.lr_scheduler.SequentialLR( + self.optimizer["bert_optimizer"], + schedulers=[zero_scheduler, warmup_scheduler], + milestones=[self.args['bert_start_finetuning']]) + + def update(self, batch, eval=False): + device = 
next(self.model.parameters()).device + inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device) + word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel = inputs + + if eval: + self.model.eval() + else: + self.model.train() + for opt in self.optimizer.values(): + opt.zero_grad() + loss, _ = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text) + loss_val = loss.data.item() + if eval: + return loss_val + + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) + for opt in self.optimizer.values(): + opt.step() + for scheduler in self.scheduler.values(): + scheduler.step() + return loss_val + + def predict(self, batch, unsort=True): + device = next(self.model.parameters()).device + inputs, orig_idx, word_orig_idx, sentlens, wordlens, text = unpack_batch(batch, device) + word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel = inputs + + self.model.eval() + batch_size = word.size(0) + _, preds = self.model(word, word_mask, wordchars, wordchars_mask, upos, xpos, ufeats, pretrained, lemma, head, deprel, word_orig_idx, sentlens, wordlens, text) + head_seqs = [chuliu_edmonds_one_root(adj[:l, :l])[1:] for adj, l in zip(preds[0], sentlens)] # remove attachment for the root + deprel_seqs = [self.vocab['deprel'].unmap([preds[1][i][j+1][h] for j, h in enumerate(hs)]) for i, hs in enumerate(head_seqs)] + + pred_tokens = [[[str(head_seqs[i][j]), deprel_seqs[i][j]] for j in range(sentlens[i]-1)] for i in range(batch_size)] + if unsort: + pred_tokens = utils.unsort(pred_tokens, orig_idx) + return pred_tokens + + def save(self, filename, skip_modules=True, save_optimizer=False): + model_state = self.model.state_dict() + # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file + 
if skip_modules: + skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules] + for k in skipped: + del model_state[k] + params = { + 'model': model_state, + 'vocab': self.vocab.state_dict(), + 'config': self.args, + 'global_step': self.global_step, + 'last_best_step': self.last_best_step, + 'dev_score_history': self.dev_score_history, + } + if self.args.get('use_peft', False): + # Hide import so that peft dependency is optional + from peft import get_peft_model_state_dict + params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name) + + if save_optimizer and self.optimizer is not None: + params['optimizer_state_dict'] = {k: opt.state_dict() for k, opt in self.optimizer.items()} + params['scheduler_state_dict'] = {k: scheduler.state_dict() for k, scheduler in self.scheduler.items()} + + try: + torch.save(params, filename, _use_new_zipfile_serialization=False) + logger.info("Model saved to {}".format(filename)) + except BaseException: + logger.warning("Saving failed... continuing anyway.") + + def load(self, filename, pretrain, args=None, foundation_cache=None, device=None): + """ + Load a model from file, with preloaded pretrain embeddings. Here we allow the pretrain to be None or a dummy input, + and the actual use of pretrain embeddings will depend on the boolean config "pretrain" in the loaded args. 
+ """ + try: + checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) + except BaseException: + logger.error("Cannot load model from {}".format(filename)) + raise + self.args = checkpoint['config'] + if args is not None: self.args.update(args) + + # preserve old models which were created before transformers were added + if 'bert_model' not in self.args: + self.args['bert_model'] = None + + lora_weights = checkpoint.get('bert_lora') + if lora_weights: + logger.debug("Found peft weights for depparse; loading a peft adapter") + self.args["use_peft"] = True + + self.vocab = MultiVocab.load_state_dict(checkpoint['vocab']) + # load model + emb_matrix = None + if self.args['pretrain'] and pretrain is not None: # we use pretrain only if args['pretrain'] == True and pretrain is not None + emb_matrix = pretrain.emb + + # TODO: refactor this common block of code with NER + force_bert_saved = False + peft_name = None + if self.args.get('use_peft', False): + force_bert_saved = True + bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "depparse", foundation_cache) + bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name) + logger.debug("Loaded peft with name %s", peft_name) + else: + if any(x.startswith("bert_model.") for x in checkpoint['model'].keys()): + logger.debug("Model %s has a finetuned transformer. 
Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename) + foundation_cache = NoTransformerFoundationCache(foundation_cache) + force_bert_saved = True + bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache) + + self.model = Parser(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name) + self.model.load_state_dict(checkpoint['model'], strict=False) + + if device is not None: + self.model = self.model.to(device) + + self.__init_optim() + optim_state_dict = checkpoint.get("optimizer_state_dict") + if optim_state_dict: + for k, state in optim_state_dict.items(): + self.optimizer[k].load_state_dict(state) + + scheduler_state_dict = checkpoint.get("scheduler_state_dict") + if scheduler_state_dict: + for k, state in scheduler_state_dict.items(): + self.scheduler[k].load_state_dict(state) + + self.global_step = checkpoint.get("global_step", 0) + self.last_best_step = checkpoint.get("last_best_step", 0) + self.dev_score_history = checkpoint.get("dev_score_history", list()) diff --git a/stanza/stanza/models/langid/trainer.py b/stanza/stanza/models/langid/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc8258bdc257831f1ffdae5cd1473fb917ac5c2 --- /dev/null +++ b/stanza/stanza/models/langid/trainer.py @@ -0,0 +1,51 @@ +import torch +import torch.optim as optim + +from stanza.models.langid.model import LangIDBiLSTM + + +class Trainer: + + DEFAULT_BATCH_SIZE = 64 + DEFAULT_LAYERS = 2 + DEFAULT_EMBEDDING_DIM = 150 + DEFAULT_HIDDEN_DIM = 150 + + def __init__(self, config, load_model=False, device=None): + self.model_path = config["model_path"] + self.batch_size = config.get("batch_size", Trainer.DEFAULT_BATCH_SIZE) + if load_model: + self.load(config["load_name"], device) + else: + self.model = 
LangIDBiLSTM(config["char_to_idx"], config["tag_to_idx"], Trainer.DEFAULT_LAYERS, + Trainer.DEFAULT_EMBEDDING_DIM, + Trainer.DEFAULT_HIDDEN_DIM, + batch_size=self.batch_size, + weights=config["lang_weights"]).to(device) + self.optimizer = optim.AdamW(self.model.parameters()) + + def update(self, inputs): + self.model.train() + sentences, targets = inputs + self.optimizer.zero_grad() + y_hat = self.model.forward(sentences) + loss = self.model.loss(y_hat, targets) + loss.backward() + self.optimizer.step() + + def predict(self, inputs): + self.model.eval() + sentences, targets = inputs + return torch.argmax(self.model(sentences), dim=1) + + def save(self, label=None): + # save a copy of model with label + if label: + self.model.save(f"{self.model_path[:-3]}-{label}.pt") + self.model.save(self.model_path) + + def load(self, model_path=None, device=None): + if not model_path: + model_path = self.model_path + self.model = LangIDBiLSTM.load(model_path, device, self.batch_size) + diff --git a/stanza/stanza/models/lemma/__init__.py b/stanza/stanza/models/lemma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stanza/stanza/models/lemma/data.py b/stanza/stanza/models/lemma/data.py new file mode 100644 index 0000000000000000000000000000000000000000..d8ff741e89f41c19bf199d30d13c5fbe83bdee01 --- /dev/null +++ b/stanza/stanza/models/lemma/data.py @@ -0,0 +1,212 @@ +import random +import numpy as np +import os +from collections import Counter +import logging +import torch + +import stanza.models.common.seq2seq_constant as constant +from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all +from stanza.models.common.vocab import DeltaVocab +from stanza.models.lemma.vocab import Vocab, MultiVocab +from stanza.models.lemma import edit +from stanza.models.common.doc import * + +logger = logging.getLogger('stanza') + +class DataLoader: + def __init__(self, doc, 
# ==== file: stanza/stanza/models/lemma/data.py ====
import random
import numpy as np
import os
from collections import Counter
import logging
import torch

import stanza.models.common.seq2seq_constant as constant
from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all
from stanza.models.common.vocab import DeltaVocab
from stanza.models.lemma.vocab import Vocab, MultiVocab
from stanza.models.lemma import edit
from stanza.models.common.doc import *

logger = logging.getLogger('stanza')

class DataLoader:
    """
    Batched data loader for the lemmatizer seq2seq model.

    Reads (text, upos, lemma) triples out of a Document, builds or reuses char/pos
    vocabs, preprocesses each word into char id sequences, and chunks into batches.
    """
    def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, conll_only=False, skip=None, expand_unk_vocab=False):
        self.batch_size = batch_size
        self.args = args
        self.eval = evaluation
        self.shuffled = not self.eval
        self.doc = doc

        data = self.raw_data()

        if conll_only: # only load conll file
            return

        if skip is not None:
            assert len(data) == len(skip)
            data = [x for x, y in zip(data, skip) if not y]

        # handle vocab
        if vocab is not None:
            if expand_unk_vocab:
                # grow the char vocab with characters seen in this data
                pos_vocab = vocab['pos']
                char_vocab = DeltaVocab(data, vocab['char'])
                self.vocab = MultiVocab({'char': char_vocab, 'pos': pos_vocab})
            else:
                self.vocab = vocab
        else:
            # (a dead `self.vocab = dict()` placeholder assignment was removed here)
            char_vocab, pos_vocab = self.init_vocab(data)
            self.vocab = MultiVocab({'char': char_vocab, 'pos': pos_vocab})

        # filter and sample data
        if args.get('sample_train', 1.0) < 1.0 and not self.eval:
            keep = int(args['sample_train'] * len(data))
            data = random.sample(data, keep)
            logger.debug("Subsample training set with rate {:g}".format(args['sample_train']))

        data = self.preprocess(data, self.vocab['char'], self.vocab['pos'], args)
        # shuffle for training
        if self.shuffled:
            indices = list(range(len(data)))
            random.shuffle(indices)
            data = [data[i] for i in indices]
        self.num_examples = len(data)

        # chunk into batches
        data = [data[i:i+batch_size] for i in range(0, len(data), batch_size)]
        self.data = data
        logger.debug("{} batches created.".format(len(data)))

    def init_vocab(self, data):
        """Build char and pos vocabs from training data; only valid during training."""
        assert self.eval is False, "Vocab file must exist for evaluation"
        char_data = "".join(d[0] + d[2] for d in data)
        char_vocab = Vocab(char_data, self.args['lang'])
        pos_data = [d[1] for d in data]
        pos_vocab = Vocab(pos_data, self.args['lang'])
        return char_vocab, pos_vocab

    def preprocess(self, data, char_vocab, pos_vocab, args):
        """Convert (word, pos, lemma) triples to id sequences plus edit type and raw text."""
        processed = []
        for d in data:
            edit_type = edit.EDIT_TO_ID[edit.get_edit_type(d[0], d[2])]
            src = list(d[0])
            src = [constant.SOS] + src + [constant.EOS]
            src = char_vocab.map(src)
            pos = d[1]
            pos = pos_vocab.unit2id(pos)
            tgt = list(d[2])
            tgt_in = char_vocab.map([constant.SOS] + tgt)
            tgt_out = char_vocab.map(tgt + [constant.EOS])
            processed += [[src, tgt_in, tgt_out, pos, edit_type, d[0]]]
        return processed

    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        """ Get a batch with index. """
        if not isinstance(key, int):
            raise TypeError
        if key < 0 or key >= len(self.data):
            raise IndexError
        batch = self.data[key]
        batch_size = len(batch)
        batch = list(zip(*batch))
        assert len(batch) == 6

        # sort all fields by lens for easy RNN operations
        lens = [len(x) for x in batch[0]]
        batch, orig_idx = sort_all(batch, lens)

        # convert to tensors
        src = batch[0]
        src = get_long_tensor(src, batch_size)
        src_mask = torch.eq(src, constant.PAD_ID)
        tgt_in = get_long_tensor(batch[1], batch_size)
        tgt_out = get_long_tensor(batch[2], batch_size)
        pos = torch.LongTensor(batch[3])
        edits = torch.LongTensor(batch[4])
        text = batch[5]
        assert tgt_in.size(1) == tgt_out.size(1), "Target input and output sequence sizes do not match."
        return src, src_mask, tgt_in, tgt_out, pos, edits, orig_idx, text

    def __iter__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

    def raw_data(self):
        return self.load_doc(self.doc, self.args.get('caseless', False), self.eval)

    @staticmethod
    def load_doc(doc, caseless, evaluation):
        """Extract (text, upos, lemma) data from a Document, cleaning goeswith/CorrectForm for training."""
        if evaluation:
            data = doc.get([TEXT, UPOS, LEMMA])
        else:
            data = doc.get([TEXT, UPOS, LEMMA, HEAD, DEPREL, MISC], as_sentences=True)
            data = DataLoader.remove_goeswith(data)
            data = DataLoader.extract_correct_forms(data)
        data = DataLoader.resolve_none(data)
        if caseless:
            data = DataLoader.lowercase_data(data)
        return data

    @staticmethod
    def extract_correct_forms(data):
        """
        Here we go through the raw data and use the CorrectForm of words tagged with CorrectForm

        In addition, if the incorrect form of the word is not present in the training data,
        we keep the incorrect form for the lemmatizer to learn from.
        This way, it can occasionally get things right in misspelled input text.

        We do check for and eliminate words where the incorrect form is already known as the
        lemma for a different word.  For example, in the English datasets, there is a "busy"
        which was meant to be "buys", and we don't want the model to learn to lemmatize "busy" to "buy"
        """
        new_data = []
        incorrect_forms = []
        for word in data:
            misc = word[-1]
            if not misc:
                new_data.append(word[:3])
                continue
            misc = misc.split("|")
            for piece in misc:
                if piece.startswith("CorrectForm="):
                    cf = piece.split("=", maxsplit=1)[1]
                    # treat the CorrectForm as the desired word
                    new_data.append((cf, word[1], word[2]))
                    # and save the broken one for later in case it wasn't used anywhere else
                    incorrect_forms.append((cf, word))
                    break
            else:
                # if no CorrectForm, just keep the word as normal
                new_data.append(word[:3])
        known_words = {x[0] for x in new_data}
        for correct_form, word in incorrect_forms:
            if word[0] not in known_words:
                new_data.append(word[:3])
        return new_data

    @staticmethod
    def remove_goeswith(data):
        """
        This method specifically removes words that goeswith something else, along with the something else

        The purpose is to eliminate text such as

1	Ken	kenrice@enroncommunications	X	GW	Typo=Yes	0	root	0:root	_
2	Rice@ENRON	_	X	GW	_	1	goeswith	1:goeswith	_
3	COMMUNICATIONS	_	X	ADD	_	1	goeswith	1:goeswith	_
        """
        filtered_data = []
        remove_indices = set()
        for sentence in data:
            remove_indices.clear()
            for word_idx, word in enumerate(sentence):
                if word[4] == 'goeswith':
                    # drop both the goeswith word and its head
                    remove_indices.add(word_idx)
                    remove_indices.add(word[3]-1)
            filtered_data.extend([x for idx, x in enumerate(sentence) if idx not in remove_indices])
        return filtered_data

    @staticmethod
    def lowercase_data(data):
        for token in data:
            token[0] = token[0].lower()
        return data

    @staticmethod
    def resolve_none(data):
        # replace None to '_'
        for tok_idx in range(len(data)):
            for feat_idx in range(len(data[tok_idx])):
                if data[tok_idx][feat_idx] is None:
                    data[tok_idx][feat_idx] = '_'
        return data
# ==== file: stanza/stanza/models/lemma_classifier/base_model.py ====
"""
Base class for the LemmaClassifier types.

Versions include LSTM and Transformer varieties
"""

import logging

from abc import ABC, abstractmethod

import os

import torch
import torch.nn as nn

from stanza.models.common.foundation_cache import load_pretrain
from stanza.models.lemma_classifier.constants import ModelType

from typing import List

logger = logging.getLogger('stanza.lemmaclassifier')

class LemmaClassifier(ABC, nn.Module):
    """
    Abstract base for lemma classifiers: holds the label maps, the set of
    target words/upos, and the save/load machinery shared by the subclasses.
    """
    def __init__(self, label_decoder, target_words, target_upos, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.label_decoder = label_decoder
        self.label_encoder = {y: x for x, y in label_decoder.items()}
        self.target_words = target_words
        self.target_upos = target_upos
        self.unsaved_modules = []

    def add_unsaved_module(self, name, module):
        """Register a module which will not be serialized with the model."""
        self.unsaved_modules += [name]
        setattr(self, name, module)

    def is_unsaved_module(self, name):
        return name.split('.')[0] in self.unsaved_modules

    def save(self, save_name):
        """
        Save the model to the given path, possibly with some args
        """
        save_dir = os.path.split(save_name)[0]
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        save_dict = self.get_save_dict()
        torch.save(save_dict, save_name)
        return save_dict

    @abstractmethod
    def model_type(self):
        """
        return a ModelType
        """

    def target_indices(self, words, tags):
        """Indices of tokens whose lowercased text and upos mark them for classification."""
        return [idx for idx, (word, tag) in enumerate(zip(words, tags)) if word.lower() in self.target_words and tag in self.target_upos]

    def predict(self, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[str]]=None) -> List[str]:
        """
        Classify the token at each position index; returns the decoded label strings.

        upos_tags defaults to an empty list (previously a mutable default argument).
        The return annotation was corrected: this returns a list of labels, not a tensor.
        """
        upos_tags = [] if upos_tags is None else upos_tags
        upos_tags = self.convert_tags(upos_tags)
        with torch.no_grad():
            logits = self.forward(position_indices, sentences, upos_tags)  # should be size (batch_size, output_size)
            predicted_class = torch.argmax(logits, dim=1)  # should be size (batch_size, 1)
            predicted_class = [self.label_encoder[x.item()] for x in predicted_class]
        return predicted_class

    @staticmethod
    def from_checkpoint(checkpoint, args=None):
        """Rebuild the appropriate subclass from a loaded checkpoint dict."""
        model_type = ModelType[checkpoint['model_type']]
        if model_type is ModelType.LSTM:
            # TODO: if anyone can suggest a way to avoid this circular import
            # (or better yet, avoid the load method knowing about subclasses)
            # please do so
            # maybe the subclassing is not necessary and we just put
            # save & load in the trainer
            from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM

            saved_args = checkpoint['args']
            # other model args are part of the model and cannot be changed for evaluation or pipeline
            # the file paths might be relevant, though
            keep_args = ['wordvec_pretrain_file', 'charlm_forward_file', 'charlm_backward_file']
            for arg in keep_args:
                if args is not None and args.get(arg, None) is not None:
                    saved_args[arg] = args[arg]

            # TODO: refactor loading the pretrain (also done in the trainer)
            pt = load_pretrain(saved_args['wordvec_pretrain_file'])

            use_charlm = saved_args['use_charlm']
            charlm_forward_file = saved_args.get('charlm_forward_file', None)
            charlm_backward_file = saved_args.get('charlm_backward_file', None)

            model = LemmaClassifierLSTM(model_args=saved_args,
                                        output_dim=len(checkpoint['label_decoder']),
                                        pt_embedding=pt,
                                        label_decoder=checkpoint['label_decoder'],
                                        upos_to_id=checkpoint['upos_to_id'],
                                        known_words=checkpoint['known_words'],
                                        target_words=set(checkpoint['target_words']),
                                        target_upos=set(checkpoint['target_upos']),
                                        use_charlm=use_charlm,
                                        charlm_forward_file=charlm_forward_file,
                                        charlm_backward_file=charlm_backward_file)
        elif model_type is ModelType.TRANSFORMER:
            from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer

            output_dim = len(checkpoint['label_decoder'])
            saved_args = checkpoint['args']
            bert_model = saved_args['bert_model']
            model = LemmaClassifierWithTransformer(model_args=saved_args,
                                                   output_dim=output_dim,
                                                   transformer_name=bert_model,
                                                   label_decoder=checkpoint['label_decoder'],
                                                   target_words=set(checkpoint['target_words']),
                                                   target_upos=set(checkpoint['target_upos']))
        else:
            raise ValueError("Unknown model type %s" % model_type)

        # strict=False to accommodate missing parameters from the transformer or charlm
        model.load_state_dict(checkpoint['params'], strict=False)
        return model

    @staticmethod
    def load(filename, args=None):
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage)
        except BaseException:
            logger.exception("Cannot load model from %s", filename)
            raise

        logger.debug("Loading LemmaClassifier model from %s", filename)

        # BUG FIX: args was accepted but previously dropped (from_checkpoint was
        # called without it), so caller-supplied path overrides were silently ignored
        return LemmaClassifier.from_checkpoint(checkpoint, args)
+""" + +import stanza +import os +from stanza.models.lemma_classifier.evaluate_models import evaluate_sequences +from stanza.models.lemma_classifier.prepare_dataset import load_doc_from_conll_file + +class BaselineModel: + + def __init__(self, token_to_lemmatize, prediction_lemma, prediction_upos): + self.token_to_lemmatize = token_to_lemmatize + self.prediction_lemma = prediction_lemma + self.prediction_upos = prediction_upos + + def predict(self, token): + if token == self.token_to_lemmatize: + return self.prediction_lemma + + def evaluate(self, conll_path): + """ + Evaluates the baseline model against the test set defined in conll_path. + + Returns a map where the keys are each class and the values are another map including the precision, recall and f1 scores + for that class. + + Also returns confusion matrix. Keys are gold tags and inner keys are predicted tags + """ + doc = load_doc_from_conll_file(conll_path) + gold_tag_sequences, pred_tag_sequences = [], [] + for sentence in doc.sentences: + gold_tags, pred_tags = [], [] + for word in sentence.words: + if word.upos in self.prediction_upos and word.text == self.token_to_lemmatize: + pred = self.prediction_lemma + gold = word.lemma + gold_tags.append(gold) + pred_tags.append(pred) + gold_tag_sequences.append(gold_tags) + pred_tag_sequences.append(pred_tags) + + multiclass_result, confusion_mtx, weighted_f1 = evaluate_sequences(gold_tag_sequences, pred_tag_sequences) + return multiclass_result, confusion_mtx + + +if __name__ == "__main__": + + bl_model = BaselineModel("'s", "be", ["AUX"]) + coNLL_path = os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu") + bl_model.evaluate(coNLL_path) + diff --git a/stanza/stanza/models/lemma_classifier/lstm_model.py b/stanza/stanza/models/lemma_classifier/lstm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef133adaf48317e456fd30bdfe3e5db5eb04ca8 --- /dev/null +++ b/stanza/stanza/models/lemma_classifier/lstm_model.py @@ -0,0 
# ==== file: stanza/stanza/models/lemma_classifier/lstm_model.py ====
import torch
import torch.nn as nn
import os
import logging
import math
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
from typing import List, Tuple

from stanza.models.common.vocab import UNK_ID
from stanza.models.lemma_classifier import utils
from stanza.models.lemma_classifier.base_model import LemmaClassifier
from stanza.models.lemma_classifier.constants import ModelType

logger = logging.getLogger('stanza.lemmaclassifier')

class LemmaClassifierLSTM(LemmaClassifier):
    """
    Model architecture:
        Extracts word embeddings over the sentence, passes embeddings into a bi-LSTM to get a sentence encoding.
        From the LSTM output, we get the embedding of the specific token that we classify on. That embedding
        is fed into an MLP for classification.
    """
    def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words, target_words, target_upos,
                 use_charlm=False, charlm_forward_file=None, charlm_backward_file=None):
        """
        Args:
            model_args (dict): model hyperparameters (hidden_dim, num_heads, upos_emb_dim, ...)
            output_dim (int): Size of output vector from MLP layer
            pt_embedding (Pretrain): pretrained embeddings
            label_decoder (Mapping[str, int]): label to id
            upos_to_id (Mapping[str, int]): A dictionary mapping UPOS tag strings to their respective IDs
            known_words (list(str)): Words which are in the training data
            target_words (set(str)): a set of the words which might need lemmatization
            target_upos (set(str)): upos tags of the target words
            use_charlm (bool): Whether or not to use the charlm embeddings
            charlm_forward_file (str): The path to the forward pass model for the character language model
            charlm_backward_file (str): The path to the backward pass model for the character language model.

        If model_args['num_heads'] > 0, multihead attention is used instead of the LSTM.

        Raises:
            FileNotFoundError: if the forward or backward charlm file cannot be found.
        """
        super(LemmaClassifierLSTM, self).__init__(label_decoder, target_words, target_upos)
        self.model_args = model_args

        self.hidden_dim = model_args['hidden_dim']
        self.input_size = 0
        self.num_heads = self.model_args['num_heads']

        emb_matrix = pt_embedding.emb
        self.add_unsaved_module("embeddings", nn.Embedding.from_pretrained(emb_matrix, freeze=True))
        self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pt_embedding.vocab) }
        self.vocab_size = emb_matrix.shape[0]
        self.embedding_dim = emb_matrix.shape[1]

        # delta embedding: a small trainable correction on top of the frozen pretrain,
        # indexed by words seen in training (0 is the padding/unknown slot)
        self.known_words = known_words
        self.known_word_map = {word: idx for idx, word in enumerate(known_words)}
        self.delta_embedding = nn.Embedding(num_embeddings=len(known_words)+1,
                                            embedding_dim=self.embedding_dim,
                                            padding_idx=0)
        nn.init.normal_(self.delta_embedding.weight, std=0.01)

        self.input_size += self.embedding_dim

        # Optionally, include charlm embeddings
        self.use_charlm = use_charlm

        if self.use_charlm:
            if charlm_forward_file is None or not os.path.exists(charlm_forward_file):
                raise FileNotFoundError(f'Could not find forward character model: {charlm_forward_file}')
            if charlm_backward_file is None or not os.path.exists(charlm_backward_file):
                raise FileNotFoundError(f'Could not find backward character model: {charlm_backward_file}')
            self.add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(charlm_forward_file, finetune=False))
            self.add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(charlm_backward_file, finetune=False))

            self.input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim()

        self.upos_emb_dim = self.model_args["upos_emb_dim"]
        self.upos_to_id = upos_to_id
        if self.upos_emb_dim > 0 and self.upos_to_id is not None:
            # TODO: should leave space for unknown POS?
            self.upos_emb = nn.Embedding(num_embeddings=len(self.upos_to_id),
                                         embedding_dim=self.upos_emb_dim,
                                         padding_idx=0)
            self.input_size += self.upos_emb_dim

        device = next(self.parameters()).device
        # Determine if attn or LSTM should be used
        if self.num_heads > 0:
            self.input_size = utils.round_up_to_multiple(self.input_size, self.num_heads)
            self.multihead_attn = nn.MultiheadAttention(embed_dim=self.input_size, num_heads=self.num_heads, batch_first=True).to(device)
            logger.debug(f"Using attention mechanism with embed dim {self.input_size} and {self.num_heads} attention heads.")
        else:
            self.lstm = nn.LSTM(self.input_size,
                                self.hidden_dim,
                                batch_first=True,
                                bidirectional=True)
            logger.debug(f"Using LSTM mechanism.")

        mlp_input_size = self.hidden_dim * 2 if self.num_heads == 0 else self.input_size
        self.mlp = nn.Sequential(
            nn.Linear(mlp_input_size, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def get_save_dict(self):
        """Checkpoint dict with unsaved-module parameters stripped out."""
        save_dict = {
            "params": self.state_dict(),
            "label_decoder": self.label_decoder,
            "model_type": self.model_type().name,
            "args": self.model_args,
            "upos_to_id": self.upos_to_id,
            "known_words": self.known_words,
            "target_words": list(self.target_words),
            "target_upos": list(self.target_upos),
        }
        skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)]
        for k in skipped:
            del save_dict["params"][k]
        return save_dict

    def convert_tags(self, upos_tags: List[List[str]]):
        if self.upos_to_id is not None:
            return [[self.upos_to_id[x] for x in sentence] for sentence in upos_tags]
        return None

    def forward(self, pos_indices: List[int], sentences: List[List[str]], upos_tags: List[List[int]]):
        """
        Computes the forward pass of the neural net

        Args:
            pos_indices (List[int]): A list of the position index of the target token for lemmatization classification in each sentence.
            sentences (List[List[str]]): A list of the token-split sentences of the input data.
            upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence.

        Returns:
            torch.tensor: Output logits of the neural network, where the shape is (n, output_size) where n is the number of sentences.
        """
        device = next(self.parameters()).device
        batch_size = len(sentences)
        token_ids = []
        delta_token_ids = []
        for words in sentences:
            sentence_token_ids = [self.vocab_map.get(word.lower(), UNK_ID) for word in words]
            token_ids.append(torch.tensor(sentence_token_ids, device=device))

            sentence_delta_token_ids = [self.known_word_map.get(word.lower(), 0) for word in words]
            delta_token_ids.append(torch.tensor(sentence_delta_token_ids, device=device))

        token_ids = pad_sequence(token_ids, batch_first=True)
        delta_token_ids = pad_sequence(delta_token_ids, batch_first=True)
        embedded = self.embeddings(token_ids) + self.delta_embedding(delta_token_ids)

        if self.upos_emb_dim > 0:
            upos_tags = [torch.tensor(sentence_tags) for sentence_tags in upos_tags]  # convert internal lists to tensors
            upos_tags = pad_sequence(upos_tags, batch_first=True, padding_value=0).to(device)
            pos_emb = self.upos_emb(upos_tags)
            embedded = torch.cat((embedded, pos_emb), 2).to(device)

        if self.use_charlm:
            char_reps_forward = self.charmodel_forward.build_char_representation(sentences)  # takes [[str]]
            char_reps_backward = self.charmodel_backward.build_char_representation(sentences)

            char_reps_forward = pad_sequence(char_reps_forward, batch_first=True)
            char_reps_backward = pad_sequence(char_reps_backward, batch_first=True)

            embedded = torch.cat((embedded, char_reps_forward, char_reps_backward), 2)

        # `embedded` is already a padded (batch, seq, feat) tensor at this point.
        # BUG FIX: previously the lengths were taken from the rows of the padded
        # tensor, so every sentence "length" was the padded maximum; the packed
        # LSTM therefore processed padding, which changes the backward direction
        # of the bi-LSTM.  Use the true sentence lengths instead.
        # (pack_padded_sequence requires lengths on CPU)
        lengths = torch.tensor([len(words) for words in sentences])

        if self.num_heads > 0:
            def positional_encoding(seq_len, d_model, device):
                # standard sinusoidal positional encoding, shape (1, seq_len, d_model)
                encoding = torch.zeros(seq_len, d_model, device=device)
                position = torch.arange(0, seq_len, dtype=torch.float, device=device).unsqueeze(1)
                div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).to(device)

                encoding[:, 0::2] = torch.sin(position * div_term)
                encoding[:, 1::2] = torch.cos(position * div_term)

                # Add a new dimension to fit the batch size
                encoding = encoding.unsqueeze(0)
                return encoding

            seq_len, d_model = embedded.shape[1], embedded.shape[2]
            pos_enc = positional_encoding(seq_len, d_model, device=device)
            embedded += pos_enc.expand_as(embedded)

            # causal mask of shape (batch * heads, tgt, src); the previous
            # view/repeat(1,1,1,1)/view round trip was a no-op and was removed
            attn_mask = torch.triu(torch.ones(batch_size * self.num_heads, seq_len, seq_len, dtype=torch.bool), diagonal=1).to(device)

            attn_output, attn_weights = self.multihead_attn(embedded, embedded, embedded, attn_mask=attn_mask)
            # Extract the hidden state at the index of the token to classify
            token_reps = attn_output[torch.arange(attn_output.size(0)), pos_indices]
        else:
            # enforce_sorted=False: sentences arrive in arbitrary length order
            packed_sequences = pack_padded_sequence(embedded, lengths, batch_first=True, enforce_sorted=False)
            lstm_out, (hidden, _) = self.lstm(packed_sequences)
            # Extract the hidden state at the index of the token to classify
            unpacked_lstm_outputs, _ = pad_packed_sequence(lstm_out, batch_first=True)
            token_reps = unpacked_lstm_outputs[torch.arange(unpacked_lstm_outputs.size(0)), pos_indices]

        # MLP forward pass
        output = self.mlp(token_reps)
        return output

    def model_type(self):
        return ModelType.LSTM
0000000000000000000000000000000000000000..5c6f73f855c459d838a787e9be818760d0a9f5db --- /dev/null +++ b/stanza/stanza/models/mwt/data.py @@ -0,0 +1,182 @@ +import random +import numpy as np +import os +from collections import Counter, namedtuple +import logging + +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader as DL + +import stanza.models.common.seq2seq_constant as constant +from stanza.models.common.data import map_to_ids, get_long_tensor, get_float_tensor, sort_all +from stanza.models.common.vocab import DeltaVocab +from stanza.models.mwt.vocab import Vocab +from stanza.models.common.doc import Document + +logger = logging.getLogger('stanza') + +DataSample = namedtuple("DataSample", "src tgt_in tgt_out orig_text") +DataBatch = namedtuple("DataBatch", "src src_mask tgt_in tgt_out orig_text orig_idx") + +# enforce that the MWT splitter knows about a couple different alternate apostrophes +# including covering some potential " typos +# setting the augmentation to a very low value should be enough to teach it +# about the unknown characters without messing up the predictions for other text +# +# 0x22, 0x27, 0x02BC, 0x02CA, 0x055A, 0x07F4, 0x2019, 0xFF07 +APOS = ('"', "'", 'ʼ', 'ˊ', '՚', 'ߴ', '’', ''') + +class DataLoader: + def __init__(self, doc, batch_size, args, vocab=None, evaluation=False, expand_unk_vocab=False): + self.batch_size = batch_size + self.args = args + self.augment_apos = args.get('augment_apos', 0.0) + self.evaluation = evaluation + self.doc = doc + + data = self.load_doc(self.doc, evaluation=self.evaluation) + + # handle vocab + if vocab is None: + assert self.evaluation == False # for eval vocab must exist + self.vocab = self.init_vocab(data) + if self.augment_apos > 0 and any(x in self.vocab for x in APOS): + for apos in APOS: + self.vocab.add_unit(apos) + elif expand_unk_vocab: + self.vocab = DeltaVocab(data, vocab) + else: + self.vocab = vocab + + # filter and sample data + if 
args.get('sample_train', 1.0) < 1.0 and not self.evaluation: + keep = int(args['sample_train'] * len(data)) + data = random.sample(data, keep) + logger.debug("Subsample training set with rate {:g}".format(args['sample_train'])) + + # shuffle for training + if not self.evaluation: + indices = list(range(len(data))) + random.shuffle(indices) + data = [data[i] for i in indices] + + self.data = data + self.num_examples = len(data) + + def init_vocab(self, data): + assert self.evaluation == False # for eval vocab must exist + vocab = Vocab(data, self.args['shorthand']) + return vocab + + def maybe_augment_apos(self, datum): + for original in APOS: + if original in datum[0]: + if random.uniform(0,1) < self.augment_apos: + replacement = random.choice(APOS) + datum = (datum[0].replace(original, replacement), datum[1].replace(original, replacement)) + break + return datum + + def process(self, sample): + if not self.evaluation and self.augment_apos > 0: + sample = self.maybe_augment_apos(sample) + src = list(sample[0]) + src = [constant.SOS] + src + [constant.EOS] + tgt_in, tgt_out = self.prepare_target(self.vocab, sample) + src = self.vocab.map(src) + processed = [src, tgt_in, tgt_out, sample[0]] + return processed + + def prepare_target(self, vocab, datum): + if self.evaluation: + tgt = list(datum[0]) # as a placeholder + else: + tgt = list(datum[1]) + tgt_in = vocab.map([constant.SOS] + tgt) + tgt_out = vocab.map(tgt + [constant.EOS]) + return tgt_in, tgt_out + + def __len__(self): + return len(self.data) + + def __getitem__(self, key): + """ Get a batch with index. 
""" + if not isinstance(key, int): + raise TypeError + if key < 0 or key >= len(self.data): + raise IndexError + sample = self.data[key] + sample = self.process(sample) + assert len(sample) == 4 + + src = torch.tensor(sample[0]) + tgt_in = torch.tensor(sample[1]) + tgt_out = torch.tensor(sample[2]) + orig_text = sample[3] + result = DataSample(src, tgt_in, tgt_out, orig_text), key + return result + + @staticmethod + def __collate_fn(data): + (data, idx) = zip(*data) + (src, tgt_in, tgt_out, orig_text) = zip(*data) + + # collate_fn is given a list of length batch size + batch_size = len(data) + + # need to sort by length of src to properly handle + # the batching in the model itself + lens = [len(x) for x in src] + (src, tgt_in, tgt_out, orig_text), orig_idx = sort_all((src, tgt_in, tgt_out, orig_text), lens) + lens = [len(x) for x in src] + + # convert to tensors + src = pad_sequence(src, True, constant.PAD_ID) + src_mask = torch.eq(src, constant.PAD_ID) + tgt_in = pad_sequence(tgt_in, True, constant.PAD_ID) + tgt_out = pad_sequence(tgt_out, True, constant.PAD_ID) + assert tgt_in.size(1) == tgt_out.size(1), \ + "Target input and output sequence sizes do not match." 
+ return DataBatch(src, src_mask, tgt_in, tgt_out, orig_text, orig_idx) + + def __iter__(self): + for i in range(self.__len__()): + yield self.__getitem__(i) + + def to_loader(self): + """Converts self to a DataLoader """ + + batch_size = self.batch_size + shuffle = not self.evaluation + return DL(self, + collate_fn=self.__collate_fn, + batch_size=batch_size, + shuffle=shuffle) + + def load_doc(self, doc, evaluation=False): + data = doc.get_mwt_expansions(evaluation) + if evaluation: data = [[e] for e in data] + return data + +class BinaryDataLoader(DataLoader): + """ + This version of the DataLoader performs the same tasks as the regular DataLoader, + except the targets are arrays of 0/1 indicating if the character is the location + of an MWT split + """ + def prepare_target(self, vocab, datum): + src = datum[0] if self.evaluation else datum[1] + binary = [0] + has_space = False + for char in src: + if char == ' ': + has_space = True + elif has_space: + has_space = False + binary.append(1) + else: + binary.append(0) + binary.append(0) + return binary, binary + diff --git a/stanza/stanza/models/mwt/scorer.py b/stanza/stanza/models/mwt/scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..cf67974967c1e9ff4af69d29ef5c234910782868 --- /dev/null +++ b/stanza/stanza/models/mwt/scorer.py @@ -0,0 +1,12 @@ +""" +Utils and wrappers for scoring MWT +""" +from stanza.models.common.utils import ud_scores + +def score(system_conllu_file, gold_conllu_file): + """ Wrapper for word segmenter scorer. 
""" + evaluation = ud_scores(gold_conllu_file, system_conllu_file) + el = evaluation["Words"] + p, r, f = el.precision, el.recall, el.f1 + return p, r, f + diff --git a/stanza/stanza/models/mwt/utils.py b/stanza/stanza/models/mwt/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4fc861ab35b2e72af293d928711ff617831d5eb --- /dev/null +++ b/stanza/stanza/models/mwt/utils.py @@ -0,0 +1,92 @@ +import stanza + +from stanza.models.common import doc +from stanza.models.tokenization.data import TokenizationDataset +from stanza.models.tokenization.utils import predict, decode_predictions + +def mwts_composed_of_words(doc): + """ + Return True/False if the MWTs in the doc are all exactly composed of the text in their words + """ + for sent_idx, sentence in enumerate(doc.sentences): + for token_idx, token in enumerate(sentence.tokens): + if len(token.words) > 1: + expected = "".join(x.text for x in token.words) + if token.text != expected: + return False + return True + + +def resplit_mwt(tokens, pipeline, keep_tokens=True): + """ + Uses the tokenize processor and the mwt processor in the pipeline to resplit tokens into MWT + + tokens: a list of list of string + pipeline: a Stanza pipeline which contains, at a minimum, tokenize and mwt + + keep_tokens: if True, enforce the old token boundaries by modify + the results of the tokenize inference. + Otherwise, use whatever new boundaries the model comes up with. + + between running the tokenize model and breaking the text into tokens, + we can update all_preds to use the original token boundaries + (if and only if keep_tokens == True) + + This method returns a Document with just the tokens and words annotated. 
+ """ + if "tokenize" not in pipeline.processors: + raise ValueError("Need a Pipeline with a valid tokenize processor") + if "mwt" not in pipeline.processors: + raise ValueError("Need a Pipeline with a valid mwt processor") + tokenize_processor = pipeline.processors["tokenize"] + mwt_processor = pipeline.processors["mwt"] + fake_text = "\n\n".join(" ".join(sentence) for sentence in tokens) + + # set up batches + batches = TokenizationDataset(tokenize_processor.config, + input_text=fake_text, + vocab=tokenize_processor.vocab, + evaluation=True, + dictionary=tokenize_processor.trainer.dictionary) + + all_preds, all_raw = predict(trainer=tokenize_processor.trainer, + data_generator=batches, + batch_size=tokenize_processor.trainer.args['batch_size'], + max_seqlen=tokenize_processor.config.get('max_seqlen', tokenize_processor.MAX_SEQ_LENGTH_DEFAULT), + use_regex_tokens=True, + num_workers=tokenize_processor.config.get('num_workers', 0)) + + if keep_tokens: + for sentence, pred in zip(tokens, all_preds): + char_idx = 0 + for word in sentence: + if len(word) > 0: + pred[char_idx:char_idx+len(word)-1] = 0 + if pred[char_idx+len(word)-1] == 0: + pred[char_idx+len(word)-1] = 1 + char_idx += len(word) + 1 + + _, _, document = decode_predictions(vocab=tokenize_processor.vocab, + mwt_dict=None, + orig_text=fake_text, + all_raw=all_raw, + all_preds=all_preds, + no_ssplit=True, + skip_newline=tokenize_processor.trainer.args['skip_newline'], + use_la_ittb_shorthand=tokenize_processor.trainer.args['shorthand'] == 'la_ittb') + + document = doc.Document(document, fake_text) + mwt_processor.process(document) + return document + +def main(): + pipe = stanza.Pipeline("en", processors="tokenize,mwt", package="gum") + tokens = [["I", "can't", "believe", "it"], ["I can't", "sleep"]] + doc = resplit_mwt(tokens, pipe) + print(doc) + + doc = resplit_mwt(tokens, pipe, keep_tokens=False) + print(doc) + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/models/ner/__init__.py 
b/stanza/stanza/models/ner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stanza/stanza/models/ner/trainer.py b/stanza/stanza/models/ner/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8137336b97d3b834a6949419a6eb2833225ca5d1 --- /dev/null +++ b/stanza/stanza/models/ner/trainer.py @@ -0,0 +1,268 @@ +""" +A trainer class to handle training and testing of models. +""" + +import sys +import logging +import torch +from torch import nn + +from stanza.models.common.foundation_cache import NoTransformerFoundationCache, load_bert, load_bert_with_peft +from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper +from stanza.models.common.trainer import Trainer as BaseTrainer +from stanza.models.common.vocab import VOCAB_PREFIX, VOCAB_PREFIX_SIZE +from stanza.models.common import utils, loss +from stanza.models.ner.model import NERTagger +from stanza.models.ner.vocab import MultiVocab +from stanza.models.common.crf import viterbi_decode + + +logger = logging.getLogger('stanza') + +def unpack_batch(batch, device): + """ Unpack a batch from the data loader. 
""" + inputs = [batch[0]] + inputs += [b.to(device) if b is not None else None for b in batch[1:5]] + orig_idx = batch[5] + word_orig_idx = batch[6] + char_orig_idx = batch[7] + sentlens = batch[8] + wordlens = batch[9] + charlens = batch[10] + charoffsets = batch[11] + return inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets + +def fix_singleton_tags(tags): + """ + If there are any singleton B- or E- tags, convert them to S- + """ + new_tags = list(tags) + # first update all I- tags at the start or end of sequence to B- or E- as appropriate + for idx, tag in enumerate(new_tags): + if (tag.startswith("I-") and + (idx == len(new_tags) - 1 or + (new_tags[idx+1] != "I-" + tag[2:] and new_tags[idx+1] != "E-" + tag[2:]))): + new_tags[idx] = "E-" + tag[2:] + if (tag.startswith("I-") and + (idx == 0 or + (new_tags[idx-1] != "B-" + tag[2:] and new_tags[idx-1] != "I-" + tag[2:]))): + new_tags[idx] = "B-" + tag[2:] + # now make another pass through the data to update any singleton tags, + # including ones which were turned into singletons by the previous operation + for idx, tag in enumerate(new_tags): + if (tag.startswith("B-") and + (idx == len(new_tags) - 1 or + (new_tags[idx+1] != "I-" + tag[2:] and new_tags[idx+1] != "E-" + tag[2:]))): + new_tags[idx] = "S-" + tag[2:] + if (tag.startswith("E-") and + (idx == 0 or + (new_tags[idx-1] != "B-" + tag[2:] and new_tags[idx-1] != "I-" + tag[2:]))): + new_tags[idx] = "S-" + tag[2:] + return new_tags + +class Trainer(BaseTrainer): + """ A trainer for training models. 
""" + def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device=None, + train_classifier_only=False, foundation_cache=None, second_optim=False): + if model_file is not None: + # load everything from file + self.load(model_file, pretrain, args, foundation_cache) + else: + assert all(var is not None for var in [args, vocab, pretrain]) + # build model from scratch + self.args = args + self.vocab = vocab + bert_model, bert_tokenizer = load_bert(self.args['bert_model']) + peft_name = None + if self.args['use_peft']: + # fine tune the bert if we're using peft + self.args['bert_finetune'] = True + peft_name = "ner" + # peft the lovely model + bert_model = build_peft_wrapper(bert_model, self.args, logger, adapter_name=peft_name) + + self.model = NERTagger(args, vocab, emb_matrix=pretrain.emb, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=self.args['bert_finetune'], peft_name=peft_name) + + # IMPORTANT: gradient checkpointing BREAKS peft if applied before + # 1. Apply PEFT FIRST (looksie! it's above this line) + # 2. 
Run gradient checkpointing + # https://github.com/huggingface/peft/issues/742 + if self.args.get("gradient_checkpointing", False) and self.args.get("bert_finetune", False): + self.model.bert_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + + # if this wasn't set anywhere, we use a default of the 0th tagset + # we don't set this as a default in the options so that + # we can distinguish "intentionally set to 0" and "not set at all" + if self.args.get('predict_tagset', None) is None: + self.args['predict_tagset'] = 0 + + if train_classifier_only: + logger.info('Disabling gradient for non-classifier layers') + exclude = ['tag_clf', 'crit'] + for pname, p in self.model.named_parameters(): + if pname.split('.')[0] not in exclude: + p.requires_grad = False + self.model = self.model.to(device) + if not second_optim: + self.optimizer = utils.get_optimizer(self.args['optim'], self.model, self.args['lr'], momentum=self.args['momentum'], bert_learning_rate=self.args.get('bert_learning_rate', 0.0), is_peft=self.args.get("use_peft")) + else: + self.optimizer = utils.get_optimizer(self.args['second_optim'], self.model, self.args['second_lr'], momentum=self.args['momentum'], bert_learning_rate=self.args.get('second_bert_learning_rate', 0.0), is_peft=self.args.get("use_peft")) + + def update(self, batch, eval=False): + device = next(self.model.parameters()).device + inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(batch, device) + word, wordchars, wordchars_mask, chars, tags = inputs + + if eval: + self.model.eval() + else: + self.model.train() + self.optimizer.zero_grad() + loss, _, _ = self.model(word, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx) + loss_val = loss.data.item() + if eval: + return loss_val + + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) + 
self.optimizer.step() + return loss_val + + def predict(self, batch, unsort=True): + device = next(self.model.parameters()).device + inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(batch, device) + word, wordchars, wordchars_mask, chars, tags = inputs + + self.model.eval() + #batch_size = word.size(0) + _, logits, trans = self.model(word, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx) + + # decode + # TODO: might need to decode multiple columns of output for + # models with multiple layers + trans = [x.data.cpu().numpy() for x in trans] + logits = [x.data.cpu().numpy() for x in logits] + batch_size = logits[0].shape[0] + if any(x.shape[0] != batch_size for x in logits): + raise AssertionError("Expected all of the logits to have the same size") + tag_seqs = [] + predict_tagset = self.args['predict_tagset'] + for i in range(batch_size): + # for each tag column in the output, decode the tag assignments + tags = [viterbi_decode(x[i, :sentlens[i]], y)[0] for x, y in zip(logits, trans)] + # TODO: this is to patch that the model can sometimes predict < "O" + tags = [[x if x >= VOCAB_PREFIX_SIZE else VOCAB_PREFIX_SIZE for x in y] for y in tags] + # that gives us N lists of |sent| tags, whereas we want |sent| lists of N tags + tags = list(zip(*tags)) + # now unmap that to the tags in the vocab + tags = self.vocab['tag'].unmap(tags) + # for now, allow either TagVocab or CompositeVocab + # TODO: we might want to return all of the predictions + # rather than a single column + tags = [x[predict_tagset] if isinstance(x, list) else x for x in tags] + tags = fix_singleton_tags(tags) + tag_seqs += [tags] + + if unsort: + tag_seqs = utils.unsort(tag_seqs, orig_idx) + return tag_seqs + + def save(self, filename, skip_modules=True): + model_state = self.model.state_dict() + # skip saving modules like pretrained embeddings, because they are large and will be saved 
in a separate file + if skip_modules: + skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules] + for k in skipped: + del model_state[k] + params = { + 'model': model_state, + 'vocab': self.vocab.state_dict(), + 'config': self.args + } + + if self.args["use_peft"]: + from peft import get_peft_model_state_dict + params["bert_lora"] = get_peft_model_state_dict(self.model.bert_model, adapter_name=self.model.peft_name) + try: + torch.save(params, filename, _use_new_zipfile_serialization=False) + logger.info("Model saved to {}".format(filename)) + except (KeyboardInterrupt, SystemExit): + raise + except: + logger.warning("Saving failed... continuing anyway.") + + def load(self, filename, pretrain=None, args=None, foundation_cache=None): + try: + checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) + except BaseException: + logger.error("Cannot load model from {}".format(filename)) + raise + self.args = checkpoint['config'] + if args: self.args.update(args) + # if predict_tagset was not explicitly set in the args, + # we use the value the model was trained with + for keep_arg in ('predict_tagset', 'train_scheme', 'scheme'): + if self.args.get(keep_arg, None) is None: + self.args[keep_arg] = checkpoint['config'].get(keep_arg, None) + + lora_weights = checkpoint.get('bert_lora') + if lora_weights: + logger.debug("Found peft weights for NER; loading a peft adapter") + self.args["use_peft"] = True + + self.vocab = MultiVocab.load_state_dict(checkpoint['vocab']) + + emb_matrix=None + if pretrain is not None: + emb_matrix = pretrain.emb + + force_bert_saved = False + peft_name = None + if self.args.get('use_peft', False): + force_bert_saved = True + bert_model, bert_tokenizer, peft_name = load_bert_with_peft(self.args['bert_model'], "ner", foundation_cache) + bert_model = load_peft_wrapper(bert_model, lora_weights, self.args, logger, peft_name) + logger.debug("Loaded peft with name %s", peft_name) + else: + if 
any(x.startswith("bert_model.") for x in checkpoint['model'].keys()): + logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename) + foundation_cache = NoTransformerFoundationCache(foundation_cache) + force_bert_saved = True + bert_model, bert_tokenizer = load_bert(self.args.get('bert_model'), foundation_cache) + + if any(x.startswith("crit.") for x in checkpoint['model'].keys()): + logger.debug("Old model format detected. Updating to the new format with one column of tags") + checkpoint['model']['crits.0._transitions'] = checkpoint['model'].pop('crit._transitions') + checkpoint['model']['tag_clfs.0.weight'] = checkpoint['model'].pop('tag_clf.weight') + checkpoint['model']['tag_clfs.0.bias'] = checkpoint['model'].pop('tag_clf.bias') + self.model = NERTagger(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, bert_model=bert_model, bert_tokenizer=bert_tokenizer, force_bert_saved=force_bert_saved, peft_name=peft_name) + self.model.load_state_dict(checkpoint['model'], strict=False) + + # there is a possible issue with the delta embeddings. 
+ # specifically, with older models trained without the delta + # embedding matrix + # if those models have been trained with the embedding + # modifications saved as part of the base embedding, + # we need to resave the model with the updated embedding + # otherwise the resulting model will be broken + if 'delta' not in self.model.vocab and 'word_emb.weight' in checkpoint['model'].keys() and 'word_emb' in self.model.unsaved_modules: + logger.debug("Removing word_emb from unsaved_modules so that resaving %s will keep the saved embedding", filename) + self.model.unsaved_modules.remove('word_emb') + + def get_known_tags(self): + """ + Return the tags known by this model + + Removes the S-, B-, etc, and does not include O + """ + tags = set() + for tag in self.vocab['tag'].items(0): + if tag in VOCAB_PREFIX: + continue + if tag == 'O': + continue + if len(tag) > 2 and tag[:2] in ('S-', 'B-', 'I-', 'E-'): + tag = tag[2:] + tags.add(tag) + return sorted(tags) diff --git a/stanza/stanza/models/tokenization/vocab.py b/stanza/stanza/models/tokenization/vocab.py new file mode 100644 index 0000000000000000000000000000000000000000..8800bed3a377292f89fc02bff22107d6d095a9c5 --- /dev/null +++ b/stanza/stanza/models/tokenization/vocab.py @@ -0,0 +1,35 @@ +from collections import Counter +import re + +from stanza.models.common.vocab import BaseVocab +from stanza.models.common.vocab import UNK, PAD + +SPACE_RE = re.compile(r'\s') + +class Vocab(BaseVocab): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lang_replaces_spaces = any([self.lang.startswith(x) for x in ['zh', 'ja', 'ko']]) + + def build_vocab(self): + paras = self.data + counter = Counter() + for para in paras: + for unit in para: + normalized = self.normalize_unit(unit[0]) + counter[normalized] += 1 + + self._id2unit = [PAD, UNK] + list(sorted(list(counter.keys()), key=lambda k: counter[k], reverse=True)) + self._unit2id = {w:i for i, w in enumerate(self._id2unit)} + + def 
normalize_unit(self, unit): + # Normalize minimal units used by the tokenizer + return unit + + def normalize_token(self, token): + token = SPACE_RE.sub(' ', token.lstrip()) + + if self.lang_replaces_spaces: + token = token.replace(' ', '') + + return token diff --git a/stanza/stanza/pipeline/demo/README.md b/stanza/stanza/pipeline/demo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eceb807502e4e1d2d5ac9475cace14af55fad4c3 --- /dev/null +++ b/stanza/stanza/pipeline/demo/README.md @@ -0,0 +1,23 @@ +## Interactive Demo for Stanza + +### Requirements + +stanza, flask + +### Run the demo locally + +1. Make sure you know how to disable your browser's CORS rule. For Chrome, [this extension](https://mybrowseraddon.com/access-control-allow-origin.html) works pretty well. +2. From this directory, start the Stanza demo server + +```bash +export FLASK_APP=demo_server.py +flask run +``` + +3. In `stanza-brat.js`, uncomment the line at the top that declares `serverAddress` and point it to where your flask is serving the demo server (usually `http://localhost:5000`) + +4. Open `stanza-brat.html` in your browser (with CORS disabled) and enjoy! + +### Common issues + +Make sure you have the models corresponding to the language you want to test out locally before submitting requests to the server! (Models can be obtained by `import stanza; stanza.download()`. 
def read_split_file(split_file):
    """
    Read a split file for ALT.

    The format of the file is expected to be a list of lines such as
      URL.1234
    Here, we only care about the id.

    Returns: a set of the integer ids

    Raises ValueError if a non-blank line does not start with "URL."
    """
    with open(split_file, encoding="utf-8") as fin:
        lines = [line.strip() for line in fin]

    ids = set()
    for line in lines:
        if not line:
            continue
        first = line.split()[0]
        if not first.startswith("URL."):
            # NB: the original raise referenced a variable that was out of
            # scope after the generator expression, so a malformed line
            # produced a NameError instead of this ValueError
            raise ValueError("Unexpected line in %s: %s" % (split_file, first))
        ids.add(int(first.split(".", 1)[1]))
    return ids
def convert_alt(input_files, split_files, output_files):
    """
    Convert the ALT treebank into train/dev/test splits.

    input_files: paths to read trees from
    split_files: recommended splits from the ALT page
    output_files: where to write train/dev/test
    """
    tree_lines = read_alt_lines(input_files)

    id_sets = [read_split_file(path) for path in split_files]
    chunks = split_trees(tree_lines, id_sets)

    for trees, out_path in zip(chunks, output_files):
        print("Writing %d trees to %s" % (len(trees), out_path))
        with open(out_path, "w", encoding="utf-8") as fout:
            # ALT trees don't have ROOT at the top, so we wrap each one
            fout.write("".join("(ROOT {})\n".format(tree) for tree in trees))
+ most of these are replaced with what might be a sensible replacement +- trees with labels that don't have an obvious replacement + these trees are eliminated, 4 total +- underscores in words. those words are split into multiple words + the tagging is not going to be ideal, but the first step of training + a parser is usually to retag the words anyway, so this should be okay +- tree 14729 is really weirdly annotated. skipped +- 5373 trees total have non-projective constituents. These don't work + with the stanza parser... in order to work around this, we rearrange + them when possible. + ((X Z) Y1 Y2 ...) -> (X Y1 Y2 Z) this rearranges 3021 trees + ((X Z1 ...) Y1 Y2 ...) -> (X Y1 Y2 Z) this rearranges 403 trees + ((X Z1 ...) (tag Y1) ...) -> (X (Y1) Z) this rearranges 1258 trees + + A couple examples of things which get rearranged + (limited in scope and without the words to avoid breaking our license): + +(vp (v-fin s4_6) (conj-c s4_8) (v-fin s4_9)) (pron-pers s4_7) +--> +(vp (v-fin s4_6) (pron-pers s4_7) (conj-c s4_8) (v-fin s4_9)) + +(vp (v-fin s1_2) (v-pcp2 s1_4)) (adv s1_3) +--> +(vp (v-fin s1_2) (adv s1_3) (v-pcp2 s1_4)) + + This process leaves behind 691 trees. In some cases, the + non-projective structure is at a higher level than the attachment. + In others, there are nested non-projectivities that are not + rearranged by the above pattern. 
A couple examples: + +here, the 3-7 nonprojectivity has the 7 in a nested structure +(s + (par + (n s206_1) + (pu s206_2) + (fcl + (fcl + (pron-pers s206_3) + (fcl (pron-pers s206_7) (adv s206_8) (v-fin s206_9))) + (vp (v-fin s206_4) (v-inf s206_6)) + (pron-pers s206_5)) + (pu s206_10))) + +here, 11 is attached at a higher level than 12 & 13 +(s + (fcl + (icl + (np + (adv s223_1) + (np + (n s223_2) + (pp + (prp s223_3) + (par + (adv s223_4) + (prop s223_5) + (pu s223_6) + (prop s223_7) + (conj-c s223_8) + (np (adv s223_9) (prop s223_10)))))) + (vp (infm s223_12) (v-inf s223_13))) + (v-fin s223_11) + (pu s223_14))) + +even if we moved _6 between 2 and 7, we'd then have a completely flat +structure when moving 3..5 inside +(s + (fcl + (xx s499_1) + (np + (pp (pron-pers s499_2) (prp s499_7)) + (n s499_6)) + (v-fin s499_3) (adv s499_4) (adv s499_5) (pu s499_8))) + +""" + + +from collections import namedtuple +import io +import xml.etree.ElementTree as ET + +from tqdm import tqdm + +from stanza.models.constituency.parse_tree import Tree +from stanza.server import tsurgeon + +def read_xml_file(input_filename): + """ + Convert an XML file into a list of trees - each becomes its own object + """ + print("Reading {}".format(input_filename)) + with open(input_filename, encoding="utf-8") as fin: + lines = fin.readlines() + + sentences = [] + current_sentence = [] + in_sentence = False + for line_idx, line in enumerate(lines): + if line.startswith(" 0: + raise ValueError("Found the start of a sentence inside an existing sentence, line {}".format(line_idx)) + in_sentence = True + + if in_sentence: + current_sentence.append(line) + + if line.startswith(""): + assert in_sentence + current_sentence = [x.replace("<{parentes-udeladt}>", "") for x in current_sentence] + current_sentence = [x.replace("<{note}>", "") for x in current_sentence] + sentences.append("".join(current_sentence)) + current_sentence = [] + in_sentence = False + + assert len(current_sentence) == 0 + + 
xml_sentences = [] + for sent_idx, text in enumerate(sentences): + sentence = io.StringIO(text) + try: + tree = ET.parse(sentence) + xml_sentences.append(tree) + except ET.ParseError as e: + raise ValueError("Failed to parse sentence {}".format(sent_idx)) + + return xml_sentences + +Word = namedtuple('Word', ['word', 'tag']) +Node = namedtuple('Node', ['label', 'children']) + +class BrokenLinkError(ValueError): + def __init__(self, error): + super(BrokenLinkError, self).__init__(error) + +def process_nodes(root_id, words, nodes, visited): + """ + Given a root_id, a map of words, and a map of nodes, construct a Tree + + visited is a set of string ids and mutates over the course of the recursive call + """ + if root_id in visited: + raise ValueError("Loop in the tree!") + visited.add(root_id) + + if root_id in words: + word = words[root_id] + # big brain move: put the root_id here so we can use that to + # check the sorted order when we are done + word_node = Tree(label=root_id) + tag_node = Tree(label=word.tag, children=word_node) + return tag_node + elif root_id in nodes: + node = nodes[root_id] + children = [process_nodes(child, words, nodes, visited) for child in node.children] + return Tree(label=node.label, children=children) + else: + raise BrokenLinkError("Unknown id! {}".format(root_id)) + +def check_words(tree, tsurgeon_processor): + """ + Check that the words of a sentence are in order + + If they are not, this applies a tsurgeon to rearrange simple cases + The tsurgeon looks at the gap between words, eg _3 to _7, and looks + for the words between, such as _4 _5 _6. 
if those words are under + a node at the same level as the 3-7 node and does not include any + other nodes (such as _8), that subtree is moved to between _3 and _7 + + Example: + + (vp (v-fin s4_6) (conj-c s4_8) (v-fin s4_9)) (pron-pers s4_7) + --> + (vp (v-fin s4_6) (pron-pers s4_7) (conj-c s4_8) (v-fin s4_9)) + """ + while True: + words = tree.leaf_labels() + indices = [int(w.split("_", 1)[1]) for w in words] + for word_idx, word_label in enumerate(indices): + if word_idx != word_label - 1: + break + else: + # if there are no weird indices, keep the tree + return tree + + sorted_indices = sorted(indices) + if indices == sorted_indices: + raise ValueError("Skipped index! This should already be accounted for {}".format(tree)) + + if word_idx == 0: + return None + + prefix = words[0].split("_", 1)[0] + prev_idx = word_idx - 1 + prev_label = indices[prev_idx] + missing_words = ["%s_%d" % (prefix, x) for x in range(prev_label + 1, word_label)] + missing_words = "|".join(missing_words) + #move_tregex = "%s > (__=home > (__=parent > __=grandparent)) . 
(%s > (__=move > =grandparent))" % (words[word_idx], "|".join(missing_words)) + move_tregex = "%s > (__=home > (__=parent << %s $+ (__=move <<, %s <<- %s)))" % (words[word_idx], words[prev_idx], missing_words, missing_words) + move_tsurgeon = "move move $+ home" + modified = tsurgeon_processor.process(tree, move_tregex, move_tsurgeon)[0] + if modified == tree: + # this only happens if the desired fix didn't happen + #print("Failed to process:\n {}\n {} {}".format(tree, prev_label, word_label)) + return None + + tree = modified + +def replace_words(tree, words): + """ + Remap the leaf words given a map of the labels we expect in the leaves + """ + leaves = tree.leaf_labels() + new_words = [words[w].word for w in leaves] + new_tree = tree.replace_words(new_words) + return new_tree + +def process_tree(sentence): + """ + Convert a single ET element representing a Tiger tree to a parse tree + """ + sentence = sentence.getroot() + sent_id = sentence.get("id") + if sent_id is None: + raise ValueError("Tree {} does not have an id".format(sent_idx)) + if len(sentence) > 1: + raise ValueError("Longer than expected number of items in {}".format(sent_id)) + graph = sentence.find("graph") + if not graph: + raise ValueError("Unexpected tree structure in {} : top tag is not 'graph'".format(sent_id)) + + root_id = graph.get("root") + if not root_id: + raise ValueError("Tree has no root id in {}".format(sent_id)) + + terminals = graph.find("terminals") + if not terminals: + raise ValueError("No terminals in tree {}".format(sent_id)) + # some Arboretum graphs have two sets of nonterminals, + # apparently intentionally, so we ignore that possible error + nonterminals = graph.find("nonterminals") + if not nonterminals: + raise ValueError("No nonterminals in tree {}".format(sent_id)) + + # read the words. 
the words have ids, text, and tags which we care about + words = {} + for word in terminals: + if word.tag == 'parentes-udeladt' or word.tag == 'note': + continue + if word.tag != "t": + raise ValueError("Unexpected tree structure in {} : word with tag other than t".format(sent_id)) + word_id = word.get("id") + if not word_id: + raise ValueError("Word had no id in {}".format(sent_id)) + word_text = word.get("word") + if not word_text: + raise ValueError("Word had no text in {}".format(sent_id)) + word_pos = word.get("pos") + if not word_pos: + raise ValueError("Word had no pos in {}".format(sent_id)) + words[word_id] = Word(word_text, word_pos) + + # read the nodes. the nodes have ids, labels, and children + # some of the edges are labeled "secedge". we ignore those + nodes = {} + for nt in nonterminals: + if nt.tag != "nt": + raise ValueError("Unexpected tree structure in {} : node with tag other than nt".format(sent_id)) + nt_id = nt.get("id") + if not nt_id: + raise ValueError("NT has no id in {}".format(sent_id)) + nt_label = nt.get("cat") + if not nt_label: + raise ValueError("NT has no label in {}".format(sent_id)) + + children = [] + for child in nt: + if child.tag != "edge" and child.tag != "secedge": + raise ValueError("NT has unexpected child in {} : {}".format(sent_id, child.tag)) + if child.tag == "edge": + child_id = child.get("idref") + if not child_id: + raise ValueError("Child is missing an id in {}".format(sent_id)) + children.append(child_id) + nodes[nt_id] = Node(nt_label, children) + + if root_id not in nodes: + raise ValueError("Could not find root in nodes in {}".format(sent_id)) + + tree = process_nodes(root_id, words, nodes, set()) + return tree, words + +def word_sequence_missing_words(tree): + """ + Check if the word sequence is missing words + + Some trees skip labels, such as + (s (fcl (pron-pers s16817_1) (v-fin s16817_2) (prp s16817_3) (pp (prp s16817_5) (par (n s16817_6) (conj-c s16817_7) (n s16817_8))) (pu s16817_9))) + but in these 
cases, the word is present in the original text and simply not attached to the tree + """ + words = tree.leaf_labels() + indices = [int(w.split("_")[1]) for w in words] + indices = sorted(indices) + for idx, label in enumerate(indices): + if label != idx + 1: + return True + return False + +WORD_TO_PHRASE = { + "art": "advp", # "en smule" is the one time this happens. it is used as an advp elsewhere + "adj": "adjp", + "adv": "advp", + "conj": "cp", + "intj": "fcl", # not sure? seems to match "hold kæft" when it shows up + "n": "np", + "num": "np", # would prefer something like QP from PTB + "pron": "np", # ?? + "prop": "np", + "prp": "pp", + "v": "vp", +} + +def split_underscores(tree): + assert not tree.is_leaf(), "Should never reach a leaf in this code path" + + if tree.is_preterminal(): + return tree + + children = tree.children + new_children = [] + for child in children: + if child.is_preterminal(): + if '_' not in child.children[0].label: + new_children.append(child) + continue + + if child.label.split("-")[0] not in WORD_TO_PHRASE: + raise ValueError("SPLITTING {}".format(child)) + pieces = [] + for piece in child.children[0].label.split("_"): + # This may not be accurate, but we already retag the treebank anyway + if len(piece) == 0: + raise ValueError("A word started or ended with _") + pieces.append(Tree(child.label, Tree(piece))) + new_children.append(Tree(WORD_TO_PHRASE[child.label.split("-")[0]], pieces)) + else: + new_children.append(split_underscores(child)) + + return Tree(tree.label, new_children) + +REMAP_LABELS = { + "adj": "adjp", + "adv": "advp", + "intj": "fcl", + "n": "np", + "num": "np", # again, a dedicated number node would be better, but there are only a few "num" labeled + "prp": "pp", +} + + +def has_weird_constituents(tree): + """ + Eliminate a few trees with weird labels + + Eliminate p? 
there are only 3 and they have varying structure underneath + Also cl, since I have no idea how to label it and it only excludes 1 tree + """ + labels = Tree.get_unique_constituent_labels(tree) + if "p" in labels or "cl" in labels: + return True + return False + +def convert_tiger_treebank(input_filename): + sentences = read_xml_file(input_filename) + + unfixable = 0 + dangling = 0 + broken_links = 0 + missing_words = 0 + weird_constituents = 0 + trees = [] + + with tsurgeon.Tsurgeon() as tsurgeon_processor: + for sent_idx, sentence in enumerate(tqdm(sentences)): + try: + tree, words = process_tree(sentence) + + if not tree.all_leaves_are_preterminals(): + dangling += 1 + continue + + if word_sequence_missing_words(tree): + missing_words += 1 + continue + + tree = check_words(tree, tsurgeon_processor) + if tree is None: + unfixable += 1 + continue + + if has_weird_constituents(tree): + weird_constituents += 1 + continue + + tree = replace_words(tree, words) + tree = split_underscores(tree) + tree = tree.remap_constituent_labels(REMAP_LABELS) + trees.append(tree) + except BrokenLinkError as e: + # the get("id") would have failed as a different error type if missing, + # so we can safely use it directly like this + broken_links += 1 + # print("Unable to process {} because of broken links: {}".format(sentence.getroot().get("id"), e)) + + print("Found {} trees with empty nodes".format(dangling)) + print("Found {} trees with unattached words".format(missing_words)) + print("Found {} trees with confusing constituent labels".format(weird_constituents)) + print("Not able to rearrange {} nodes".format(unfixable)) + print("Unable to handle {} trees because of broken links, eg names in another tree".format(broken_links)) + print("Parsed {} trees from {}".format(len(trees), input_filename)) + return trees + +def main(): + treebank = convert_tiger_treebank("extern_data/constituency/danish/W0084/arboretum.tiger/arboretum.tiger") + +if __name__ == '__main__': + main() diff --git 
a/stanza/stanza/utils/datasets/constituency/convert_icepahc.py b/stanza/stanza/utils/datasets/constituency/convert_icepahc.py new file mode 100644 index 0000000000000000000000000000000000000000..448e6dd03ae2fc3f0b2849d08fbf2780f81a9f0d --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/convert_icepahc.py @@ -0,0 +1,83 @@
"""
Currently this doesn't function

The goal is simply to demonstrate how to use tsurgeon
"""

from stanza.models.constituency.tree_reader import read_trees, read_treebank
from stanza.server import tsurgeon

# one sample IcePaHC tree; each leaf is "form-lemma"
TREEBANK = """
( (IP-MAT (NP-SBJ (PRO-N Það-það))
          (BEPI er-vera)
          (ADVP (ADV eiginlega-eiginlega))
          (ADJP (NEG ekki-ekki) (ADJ-N hægt-hægur))
          (IP-INF (TO að-að) (VB lýsa-lýsa))
          (NP-OB1 (N-D tilfinningu$-tilfinning) (D-D $nni-hinn))
          (IP-INF (TO að-að) (VB fá-fá))
          (IP-INF (TO að-að) (VB taka-taka))
          (NP-OB1 (N-A þátt-þáttur))
          (PP (P í-í)
              (NP (D-D þessu-þessi)))
          (, ,-,)
          (VBPI segir-segja)
          (NP-SBJ (NPR-N Sverrir-sverrir) (NPR-N Ingi-ingi))
          (. .-.)))
"""

# Output of the first tsurgeon:
#(ROOT
#  (IP-MAT
#    (NP-SBJ (PRO-N Það))
#    (BEPI er)
#    (ADVP (ADV eiginlega))
#    (ADJP (NEG ekki) (ADJ-N hægt))
#    (IP-INF (TO að) (VB lýsa))
#    (NP-OB1 (N-D tilfinningu$) (D-D $nni))
#    (IP-INF (TO að) (VB fá))
#    (IP-INF (TO að) (VB taka))
#    (NP-OB1 (N-A þátt))
#    (PP
#      (P í)
#      (NP (D-D þessu)))
#    (, ,)
#    (VBPI segir)
#    (NP-SBJ (NPR-N Sverrir) (NPR-N Ingi))
#    (. .)))

# Output of the second operation
#(ROOT
#  (IP-MAT
#    (NP-SBJ (PRO-N Það))
#    (BEPI er)
#    (ADVP (ADV eiginlega))
#    (ADJP (NEG ekki) (ADJ-N hægt))
#    (IP-INF (TO að) (VB lýsa))
#    (NP-OB1 (N-D tilfinningunni))
#    (IP-INF (TO að) (VB fá))
#    (IP-INF (TO að) (VB taka))
#    (NP-OB1 (N-A þátt))
#    (PP
#      (P í)
#      (NP (D-D þessu)))
#    (, ,)
#    (VBPI segir)
#    (NP-SBJ (NPR-N Sverrir) (NPR-N Ingi))
#    (. .)))


treebank = read_trees(TREEBANK)

with tsurgeon.Tsurgeon(classpath="$CLASSPATH") as tsurgeon_processor:
    # pass 1: strip the "-lemma" suffix from every leaf by capturing the
    # form before the dash and relabeling the word with just the form
    # (compare the first output comment above)
    form_tregex = "/^(.+)-.+$/#1%form=word !< __"
    form_tsurgeon = "relabel word /^.+$/%{form}/"

    # pass 2: rejoin split noun+determiner pairs such as
    # "tilfinningu$ $nni" -> "tilfinningunni": relabel the noun with
    # noun+det, then prune the determiner preterminal
    # (compare the second output comment above)
    noun_det_tregex = "/^N-/ < /^([^$]+)[$]$/#1%noun=noun $+ (/^D-/ < /^[$]([^$]+)$/#1%det=det)"
    noun_det_relabel = "relabel noun /^.+$/%{noun}%{det}/"
    noun_det_prune = "prune det"

    for tree in treebank:
        updated_tree = tsurgeon_processor.process(tree, (form_tregex, form_tsurgeon))[0]
        print("{:P}".format(updated_tree))
        updated_tree = tsurgeon_processor.process(updated_tree, (noun_det_tregex, noun_det_relabel, noun_det_prune))[0]
        print("{:P}".format(updated_tree))
diff --git a/stanza/stanza/utils/datasets/constituency/convert_it_turin.py b/stanza/stanza/utils/datasets/constituency/convert_it_turin.py new file mode 100644 index 0000000000000000000000000000000000000000..7662490f2d94146bcc021a31be0e2d783011f1e1 --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/convert_it_turin.py @@ -0,0 +1,339 @@
"""
Converts Turin's constituency dataset

Turin University put out a freely available constituency dataset in 2011.
It is not as large as VIT or ISST, but it is free, which is nice.

The 2011 parsing task combines trees from several sources:
http://www.di.unito.it/~tutreeb/evalita-parsingtask-11.html

There is another site for Turin treebanks:
http://www.di.unito.it/~tutreeb/treebanks.html

Weirdly, the most recent versions of the Evalita trees are not there.
The most relevant parts are the ParTUT downloads.  As of Sep.
2021: + +http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/JRCAcquis_It.pen +http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/UDHR_It.pen +http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/CC_It.pen +http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/FB_It.pen +http://www.di.unito.it/~tutreeb/corpora/Par-TUT/tutINpenn/italian/WIT3_It.pen + +We can't simply cat all these files together as there are a bunch of +asterisks as comments and the files may have some duplicates. For +example, the JRCAcquis piece has many duplicates. Also, some don't +pass validation for one reason or another. + +One oddity of these data files is that the MWT are denoted by doubling +the token. The token is not split as would be expected, though. We try +to use stanza's MWT tokenizer for IT to split the tokens, with some +rules added by hand in BIWORD_SPLITS. Two are still unsplit, though... +""" + +import glob +import os +import re +import sys + +import stanza +from stanza.models.constituency import parse_tree +from stanza.models.constituency import tree_reader + +def load_without_asterisks(in_file, encoding='utf-8'): + with open(in_file, encoding=encoding) as fin: + new_lines = [x if x.find("********") < 0 else "\n" for x in fin.readlines()] + if len(new_lines) > 0 and not new_lines[-1].endswith("\n"): + new_lines[-1] = new_lines[-1] + "\n" + return new_lines + +CONSTITUENT_SPLIT = re.compile("[-=#+0-9]") + +# JRCA is almost entirely duplicates +# WIT3 follows a different annotation scheme +FILES_TO_ELIMINATE = ["JRCAcquis_It.pen", "WIT3_It.pen"] + +# assuming this is a typo +REMAP_NODES = { "Sbar" : "SBAR" } + +REMAP_WORDS = { "-LSB-": "[", "-RSB-": "]" } + +# these mostly seem to be mistakes +# maybe Vbar and ADVbar should be converted to something else? 
+NODES_TO_ELIMINATE = ["C", "PHRASP", "PRDT", "Vbar", "parte", "ADVbar"] + +UNKNOWN_SPLITS = set() + +# a map of splits that the tokenizer or MWT doesn't handle well +BIWORD_SPLITS = { "offertogli": ("offerto", "gli"), + "offertegli": ("offerte", "gli"), + "formatasi": ("formata", "si"), + "formatosi": ("formato", "si"), + "multiplexarlo": ("multiplexar", "lo"), + "esibirsi": ("esibir", "si"), + "pagarne": ("pagar", "ne"), + "recarsi": ("recar", "si"), + "trarne": ("trar", "ne"), + "esserci": ("esser", "ci"), + "aprirne": ("aprir", "ne"), + "farle": ("far", "le"), + "disporne": ("dispor", "ne"), + "andargli": ("andar", "gli"), + "CONSIDERARSI": ("CONSIDERAR", "SI"), + "conferitegli": ("conferite", "gli"), + "formatasi": ("formata", "si"), + "formatosi": ("formato", "si"), + "Formatisi": ("Formati", "si"), + "multiplexarlo": ("multiplexar", "lo"), + "esibirsi": ("esibir", "si"), + "pagarne": ("pagar", "ne"), + "recarsi": ("recar", "si"), + "trarne": ("trar", "ne"), + "temerne": ("temer", "ne"), + "esserci": ("esser", "ci"), + "esservi": ("esser", "vi"), + "restituirne": ("restituir", "ne"), + "col": ("con", "il"), + "cogli": ("con", "gli"), + "dirgli": ("dir", "gli"), + "opporgli": ("oppor", "gli"), + "eccolo": ("ecco", "lo"), + "Eccolo": ("Ecco", "lo"), + "Eccole": ("Ecco", "le"), + "farci": ("far", "ci"), + "farli": ("far", "li"), + "farne": ("far", "ne"), + "farsi": ("far", "si"), + "farvi": ("far", "vi"), + "Connettiti": ("Connetti", "ti"), + "APPLICARSI": ("APPLICAR", "SI"), + # This is not always two words, but if it IS two words, + # it gets split like this + "assicurati": ("assicura", "ti"), + "Fatti": ("Fai", "te"), + "ai": ("a", "i"), + "Ai": ("A", "i"), + "AI": ("A", "I"), + "al": ("a", "il"), + "Al": ("A", "il"), + "AL": ("A", "IL"), + "coi": ("con", "i"), + "colla": ("con", "la"), + "colle": ("con", "le"), + "dal": ("da", "il"), + "Dal": ("Da", "il"), + "DAL": ("DA", "IL"), + "dei": ("di", "i"), + "Dei": ("Di", "i"), + "DEI": ("DI", "I"), + "del": 
("di", "il"), + "Del": ("Di", "il"), + "DEL": ("DI", "IL"), + "nei": ("in", "i"), + "NEI": ("IN", "I"), + "nel": ("in", "il"), + "Nel": ("In", "il"), + "NEL": ("IN", "IL"), + "pel": ("per", "il"), + "sui": ("su", "i"), + "Sui": ("Su", "i"), + "sul": ("su", "il"), + "Sul": ("Su", "il"), + ",": (",", ","), + ".": (".", "."), + '"': ('"', '"'), + '-': ('-', '-'), + '-LRB-': ('-LRB-', '-LRB-'), + "garantirne": ("garantir", "ne"), + "aprirvi": ("aprir", "vi"), + "esimersi": ("esimer", "si"), + "opporsi": ("oppor", "si"), +} + +CAP_BIWORD = re.compile("[A-Z]+_[A-Z]+") + +def split_mwe(tree, pipeline): + words = list(tree.leaf_labels()) + found = False + for idx, word in enumerate(words[:-3]): + if word == words[idx+1] and word == words[idx+2] and word == words[idx+3]: + raise ValueError("Oh no, 4 consecutive words") + + for idx, word in enumerate(words[:-2]): + if word == words[idx+1] and word == words[idx+2]: + doc = pipeline(word) + assert len(doc.sentences) == 1 + if len(doc.sentences[0].words) != 3: + raise RuntimeError("Word {} not tokenized into 3 parts... 
thought all 3 part words were handled!".format(word)) + words[idx] = doc.sentences[0].words[0].text + words[idx+1] = doc.sentences[0].words[1].text + words[idx+2] = doc.sentences[0].words[2].text + found = True + + for idx, word in enumerate(words[:-1]): + if word == words[idx+1]: + if word in BIWORD_SPLITS: + first_word = BIWORD_SPLITS[word][0] + second_word = BIWORD_SPLITS[word][1] + elif CAP_BIWORD.match(word): + first_word, second_word = word.split("_") + else: + doc = pipeline(word) + assert len(doc.sentences) == 1 + if len(doc.sentences[0].words) == 2: + first_word = doc.sentences[0].words[0].text + second_word = doc.sentences[0].words[1].text + else: + if word not in UNKNOWN_SPLITS: + UNKNOWN_SPLITS.add(word) + print("Could not figure out how to split {}\n {}\n {}".format(word, " ".join(words), tree)) + continue + + words[idx] = first_word + words[idx+1] = second_word + found = True + + if found: + tree = tree.replace_words(words) + return tree + + +def load_trees(filename, pipeline): + # some of the files are in latin-1 encoding rather than utf-8 + try: + raw_text = load_without_asterisks(filename, "utf-8") + except UnicodeDecodeError: + raw_text = load_without_asterisks(filename, "latin-1") + + # also, some have messed up validation (it will be logged) + # hence the broken_ok=True argument + trees = tree_reader.read_trees("".join(raw_text), broken_ok=True) + + filtered_trees = [] + for tree in trees: + if tree.children[0].label is None: + print("Skipping a broken tree (missing label) in {}: {}".format(filename, tree)) + continue + + try: + words = tuple(tree.leaf_labels()) + except ValueError: + print("Skipping a broken tree (missing preterminal) in {}: {}".format(filename, tree)) + continue + + if any('www.facebook' in pt.label for pt in tree.preterminals()): + print("Skipping a tree with a weird preterminal label in {}: {}".format(filename, tree)) + continue + + tree = tree.prune_none().simplify_labels(CONSTITUENT_SPLIT) + + if len(tree.children) > 1: + 
print("Found a tree with a non-unary root! {}: {}".format(filename, tree)) + continue + if tree.children[0].is_preterminal(): + print("Found a tree with a single preterminal node! {}: {}".format(filename, tree)) + continue + + # The expectation is that the retagging will handle this anyway + for pt in tree.preterminals(): + if not pt.label: + pt.label = "UNK" + print("Found a tree with a blank preterminal label. Setting it to UNK. {}: {}".format(filename, tree)) + + tree = tree.remap_constituent_labels(REMAP_NODES) + tree = tree.remap_words(REMAP_WORDS) + + tree = split_mwe(tree, pipeline) + if tree is None: + continue + + constituents = set(parse_tree.Tree.get_unique_constituent_labels(tree)) + for weird_label in NODES_TO_ELIMINATE: + if weird_label in constituents: + break + else: + weird_label = None + if weird_label is not None: + print("Skipping a tree with a weird label {} in {}: {}".format(weird_label, filename, tree)) + continue + + filtered_trees.append(tree) + + return filtered_trees + +def save_trees(out_file, trees): + print("Saving {} trees to {}".format(len(trees), out_file)) + with open(out_file, "w", encoding="utf-8") as fout: + for tree in trees: + fout.write(str(tree)) + fout.write("\n") + +def convert_it_turin(input_path, output_path): + pipeline = stanza.Pipeline("it", processors="tokenize, mwt", tokenize_no_ssplit=True) + + os.makedirs(output_path, exist_ok=True) + + evalita_dir = os.path.join(input_path, "evalita") + + evalita_test = os.path.join(evalita_dir, "evalita11_TESTgold_CONPARSE.penn") + it_test = os.path.join(output_path, "it_turin_test.mrg") + test_trees = load_trees(evalita_test, pipeline) + save_trees(it_test, test_trees) + + known_text = set() + for tree in test_trees: + words = tuple(tree.leaf_labels()) + assert words not in known_text + known_text.add(words) + + evalita_train = os.path.join(output_path, "it_turin_train.mrg") + evalita_files = glob.glob(os.path.join(evalita_dir, "*2011*penn")) + turin_files = 
glob.glob(os.path.join(input_path, "turin", "*pen")) + filenames = evalita_files + turin_files + filtered_trees = [] + for filename in filenames: + if os.path.split(filename)[1] in FILES_TO_ELIMINATE: + continue + + trees = load_trees(filename, pipeline) + file_trees = [] + + for tree in trees: + words = tuple(tree.leaf_labels()) + if words in known_text: + print("Skipping a duplicate in {}: {}".format(filename, tree)) + continue + + known_text.add(words) + + file_trees.append(tree) + + filtered_trees.append((filename, file_trees)) + + print("{} contains {} usable trees".format(evalita_test, len(test_trees))) + print(" Unique constituents in {}: {}".format(evalita_test, parse_tree.Tree.get_unique_constituent_labels(test_trees))) + + train_trees = [] + dev_trees = [] + for filename, file_trees in filtered_trees: + print("{} contains {} usable trees".format(filename, len(file_trees))) + print(" Unique constituents in {}: {}".format(filename, parse_tree.Tree.get_unique_constituent_labels(file_trees))) + for tree in file_trees: + if len(train_trees) <= len(dev_trees) * 9: + train_trees.append(tree) + else: + dev_trees.append(tree) + + it_train = os.path.join(output_path, "it_turin_train.mrg") + save_trees(it_train, train_trees) + + it_dev = os.path.join(output_path, "it_turin_dev.mrg") + save_trees(it_dev, dev_trees) + +def main(): + input_path = sys.argv[1] + output_path = sys.argv[2] + + convert_it_turin(input_path, output_path) + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/utils/datasets/constituency/convert_it_vit.py b/stanza/stanza/utils/datasets/constituency/convert_it_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..70b2c3c4088c3f3601f7411d8c58183932fadf0f --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/convert_it_vit.py @@ -0,0 +1,700 @@ +"""Converts the proprietary VIT dataset to a format suitable for stanza + +There are multiple corrections in the UD version of VIT, along with +recommended splits 
for the MWT, along with recommended splits of +the sentences into train/dev/test + +Accordingly, it is necessary to use the UD dataset as a reference + +Here is a sample line of the text file we use: + +#ID=sent_00002 cp-[sp-[part-negli, sn-[sa-[ag-ultimi], nt-anni]], f-[sn-[art-la, n-dinamica, spd-[partd-dei, sn-[n-polo_di_attrazione]]], ibar-[ause-è, ausep-stata, savv-[savv-[avv-sempre], avv-più], vppt-caratterizzata], compin-[spda-[partda-dall, sn-[n-emergere, spd-[pd-di, sn-[art-una, sa-[ag-crescente], n-concorrenza, f2-[rel-che, f-[ibar-[clit-si, ause-è, avv-progressivamente, vppin-spostata], compin-[spda-[partda-dalle, sn-[sa-[ag-singole], n-imprese]], sp-[part-ai, sn-[n-sistemi, sa-[coord-[ag-economici, cong-e, ag-territoriali]]]], fp-[punt-',', sv5-[vgt-determinando, compt-[sn-[art-l_, nf-esigenza, spd-[pd-di, sn-[art-una, n-riconsiderazione, spd-[partd-dei, sn-[n-rapporti, sv3-[ppre-esistenti, compin-[sp-[p-tra, sn-[n-soggetti, sa-[ag-produttivi]]], cong-e, sn-[n-ambiente, f2-[sp-[p-in, sn-[relob-cui]], f-[sn-[deit-questi], ibar-[vin-operano, punto-.]]]]]]]]]]]]]]]]]]]]]]]] + +Here you can already see multiple issues when parsing: +- the first word is "negli", which is split into In_ADP gli_DET in the UD version +- also the first word is capitalized in the UD version +- comma looks like a tempting split target, but there is a ',' in this sentence + punt-',' +- not shown here is '-' which is different from the - used for denoting POS + par-'-' + +Fortunately, -[ is always an open and ] is always a close + +As of April 2022, the UD version of the dataset has some minor edits +which are necessary for the proper functioning of this script. +Otherwise, the MWT won't align correctly, some typos won't be +corrected, etc. These edits are released in UD 2.10 + +The data itself is available from ELRA: + +http://catalog.elra.info/en-us/repository/browse/ELRA-W0040/ + +Internally at Stanford you can contact Chris Manning or John Bauer. 
+ +The processing goes as follows: +- read in UD and con trees + some of the con trees have broken brackets and are discarded + in other cases, abbreviations were turned into single tokens in UD +- extract the MWT expansions of Italian contractions, + such as Negli -> In gli +- attempt to align the trees from the two datasets using ngrams + some trees had the sentence splitting updated + sentences which can't be matched are discarded +- use CoreNLP tsurgeon to update tokens in the con trees + based on the information in the UD dataset + - split contractions + - rearrange clitics which are occasionally non-projective +- replace the words in the con tree with the dep tree's words + this takes advantage of spelling & capitalization fixes + +In 2022, there was an update to the dataset from Prof. Delmonte. +This update is hopefully in current ELRA distributions now. +If not, please contact ELRA to specifically ask for the updated version. +Internally to Stanford, feel free to ask Chris or John for the updates. +Look for the line below "original version with more errors" + +In August 2022, Prof. Delmonte made a slight update in a zip file +`john.zip`. If/when that gets updated to ELRA, we will update it +here. Contact Chris or John for a copy if not updated yet, or go +back in git history to get the older version of the code which +works with the 2022 ELRA update. + +Later, in September 2022, there is yet another update, +New version of VIT.zip +Unzip the contents into a folder +$CONSTITUENCY_BASE/italian/it_vit +so there should be a file +$CONSTITUENCY_BASE/italian/it_vit/VITwritten/VITconstsyntNumb + +There are a few other updates needed to improve the annotations, +but all the nagging seemed to give Prof. Delmonte a headache, +so at this point we include those fixes in this script instead. 
+See the first few tsurgeon operations in update_mwts_and_special_cases +""" + +from collections import defaultdict, deque, namedtuple +import itertools +import os +import re +import sys + +from tqdm import tqdm + +from stanza.models.constituency.tree_reader import read_trees, UnclosedTreeError, ExtraCloseTreeError +from stanza.server import tsurgeon +from stanza.utils.conll import CoNLL +from stanza.utils.datasets.constituency.utils import SHARDS, write_dataset +import stanza.utils.default_paths as default_paths + +def read_constituency_sentences(fin): + """ + Reads the lines from the constituency treebank and splits into ID, text + + No further processing is done on the trees yet + """ + sentences = [] + for line in fin: + line = line.strip() + # WTF why doesn't strip() remove this + line = line.replace(u'\ufeff', '') + if not line: + continue + sent_id, sent_text = line.split(maxsplit=1) + # we have seen a couple different versions of this sentence header + # although one file is always consistent with itself, at least + if not sent_id.startswith("#ID=sent") and not sent_id.startswith("ID#sent"): + raise ValueError("Unexpected start of sentence: |{}|".format(sent_id)) + if not sent_text: + raise ValueError("Empty text for |{}|".format(sent_id)) + sentences.append((sent_id, sent_text)) + return sentences + +def read_constituency_file(filename): + print("Reading raw constituencies from %s" % filename) + with open(filename, encoding='utf-8') as fin: + return read_constituency_sentences(fin) + +OPEN = "-[" +CLOSE = "]" + +DATE_RE = re.compile("^([0-9]{1,2})[_]([0-9]{2})$") +INTEGER_PERCENT_RE = re.compile(r"^((?:min|plus)?[0-9]{1,3})[%]$") +DECIMAL_PERCENT_RE = re.compile(r"^((?:min|plus)?[0-9]{1,3})[/_]([0-9]{1,3})[%]$") +RANGE_PERCENT_RE = re.compile(r"^([0-9]{1,2}[/_][0-9]{1,2})[/]([0-9]{1,2}[/_][0-9]{1,2})[%]$") +DECIMAL_RE = re.compile(r"^([0-9])[_]([0-9])$") + +ProcessedTree = namedtuple('ProcessedTree', ['con_id', 'dep_id', 'tree']) + +def raw_tree(text): + 
    """
    A sentence will look like this:
    #ID=sent_00001 fc-[f3-[sn-[art-le, n-infrastrutture, sc-[ccom-come, sn-[n-fattore, spd-[pd-di,
    sn-[n-competitività]]]]]], f3-[spd-[pd-di, sn-[mw-Angela, nh-Airoldi]]], punto-.]
    Non-preterminal nodes have tags, followed by the stuff under the node, -[
    The node is closed by the ]
    """
    # tokenize the raw bracket text into opens, closes, and tag-word pieces
    pieces = []
    open_pieces = text.split(OPEN)
    for open_idx, open_piece in enumerate(open_pieces):
        if open_idx > 0:
            # re-attach the OPEN marker to the tag that preceded it
            pieces[-1] = pieces[-1] + OPEN
        open_piece = open_piece.strip()
        if not open_piece:
            raise ValueError("Unexpected empty node!")
        close_pieces = open_piece.split(CLOSE)
        for close_idx, close_piece in enumerate(close_pieces):
            if close_idx > 0:
                pieces.append(CLOSE)
            close_piece = close_piece.strip()
            if not close_piece:
                # this is okay - multiple closes at the end of a deep bracket
                continue
            word_pieces = close_piece.split(", ")
            pieces.extend([x.strip() for x in word_pieces if x.strip()])

    # at this point, pieces is a list with:
    # tag-[ for opens
    # tag-word for words
    # ] for closes
    # this structure converts pretty well to reading using the tree reader

    # hand-curated replacements for pieces whose words would otherwise
    # confuse the generic tag-word splitting below
    PIECE_MAPPING = {
        "agn-/ter'": "(agn ter)",
        "cong-'&'": "(cong &)",
        "da_riempire-'...'": "(da_riempire ...)",
        "date-1992_1993": "(date 1992/1993)",
        "date-'31-12-95'": "(date 31-12-95)",
        "date-'novantaquattro-95'":"(date novantaquattro-95)",
        "date-'novantaquattro-95": "(date novantaquattro-95)",
        "date-'novantaquattro-novantacinque'": "(date novantaquattro-novantacinque)",
        "dirs-':'": "(dirs :)",
        "dirs-'\"'": "(dirs \")",
        "mw-'&'": "(mw &)",
        "mw-'Presunto'": "(mw Presunto)",
        "nh-'Alain-Gauze'": "(nh Alain-Gauze)",
        "np-'porto_Marghera'": "(np Porto) (np Marghera)",
        "np-'roma-l_aquila'": "(np Roma-L'Aquila)",
        "np-'L_Aquila-Villa_Vomano'": "(np L'Aquila) (np -) (np Villa) (np Vomano)",
        "npro-'Avanti_!'": "(npro Avanti_!)",
        "npro-'Viacom-Paramount'": "(npro Viacom-Paramount)",
        "npro-'Rhone-Poulenc'": "(npro Rhone-Poulenc)",
        "npro-'Itar-Tass'": "(npro Itar-Tass)",
        "par-(-)": "(par -)",
        "par-','": "(par ,)",
        "par-'<'": "(par <)",
        "par-'>'": "(par >)",
        "par-'-'": "(par -)",
        "par-'\"'": "(par \")",
        "par-'('": "(par -LRB-)",
        "par-')'": "(par -RRB-)",
        "par-'&&'": "(par &&)",
        "punt-','": "(punt ,)",
        "punt-'-'": "(punt -)",
        "punt-';'": "(punt ;)",
        "punto-':'": "(punto :)",
        "punto-';'": "(punto ;)",
        "puntint-'!'": "(puntint !)",
        # NOTE(review): '?' mapping to "(puntint !)" looks like a copy/paste
        # typo for "(puntint ?)" - confirm against the VIT source data
        "puntint-'?'": "(puntint !)",
        "num-'2plus2'": "(num 2+2)",
        "num-/bis'": "(num bis)",
        "num-/ter'": "(num ter)",
        "num-18_00/1_00": "(num 18:00/1:00)",
        "num-1/500_2/000": "(num 1.500-2.000)",
        "num-16_1": "(num 16,1)",
        "num-0_1": "(num 0,1)",
        "num-0_3": "(num 0,3)",
        "num-2_7": "(num 2,7)",
        "num-455_68": "(num 455/68)",
        "num-437_5": "(num 437,5)",
        "num-4708_82": "(num 4708,82)",
        "num-16EQ517_7": "(num 16EQ517/7)",
        "num-2=184_90": "(num 2=184/90)",
        "num-3EQ429_20": "(num 3eq429/20)",
        "num-'1990-EQU-100'": "(num 1990-EQU-100)",
        "num-'500-EQU-250'": "(num 500-EQU-250)",
        # NOTE(review): key "num-0_39%minus" is duplicated below with a
        # different value; in a dict literal the later entry wins - confirm
        "num-0_39%minus": "(num 0,39) (num %%) (num -)",
        "num-1_88/76": "(num 1-88/76)",
        "num-'70/80'": "(num 70,80)",
        "num-'18/20'": "(num 18:20)",
        "num-295/mila'": "(num 295mila)",
        "num-'295/mila'": "(num 295mila)",
        "num-0/07%plus": "(num 0,07) (num %%) (num plus)",
        "num-0/69%minus": "(num 0,69) (num %%) (num minus)",
        "num-0_39%minus": "(num 0,39) (num %%) (num minus)",
        "num-9_11/16": "(num 9-11,16)",
        "num-2/184_90": "(num 2=184/90)",
        "num-3/429_20": "(num 3eq429/20)",
        # TODO: remove the following num conversions if possible
        # this would require editing either constituency or UD
        "num-1:28_124": "(num 1=8/1242)",
        "num-1:28_397": "(num 1=8/3972)",
        "num-1:28_947": "(num 1=8/9472)",
        "num-1:29_657": "(num 1=9/6572)",
        "num-1:29_867": "(num 1=9/8672)",
        "num-1:29_874": "(num 1=9/8742)",
        "num-1:30_083": "(num 1=0/0833)",
        "num-1:30_140": "(num 1=0/1403)",
        "num-1:30_354": "(num 1=0/3543)",
        "num-1:30_453": "(num 1=0/4533)",
        "num-1:30_946": "(num 1=0/9463)",
        "num-1:31_602": "(num 1=1/6023)",
        "num-1:31_842": "(num 1=1/8423)",
        "num-1:32_087": "(num 1=2/0873)",
        "num-1:32_259": "(num 1=2/2593)",
        "num-1:33_166": "(num 1=3/1663)",
        "num-1:34_154": "(num 1=4/1543)",
        "num-1:34_556": "(num 1=4/5563)",
        "num-1:35_323": "(num 1=5/3233)",
        "num-1:36_023": "(num 1=6/0233)",
        "num-1:36_076": "(num 1=6/0763)",
        "num-1:36_651": "(num 1=6/6513)",
        "n-giga_flop/s": "(n giga_flop/s)",
        "sect-'g-1'": "(sect g-1)",
        "sect-'h-1'": "(sect h-1)",
        "sect-'h-2'": "(sect h-2)",
        "sect-'h-3'": "(sect h-3)",
        "abbr-'a-b-c'": "(abbr a-b-c)",
        "abbr-d_o_a_": "(abbr DOA)",
        "abbr-d_l_": "(abbr DL)",
        "abbr-i_s_e_f_": "(abbr ISEF)",
        "abbr-d_p_r_": "(abbr DPR)",
        "abbr-D_P_R_": "(abbr DPR)",
        "abbr-d_m_": "(abbr dm)",
        "abbr-T_U_": "(abbr TU)",
        "abbr-F_A_M_E_": "(abbr Fame)",
        "dots-'...'": "(dots ...)",
    }
    # convert the pieces to an S-expression string the tree reader understands
    new_pieces = ["(ROOT "]
    for piece in pieces:
        if piece.endswith(OPEN):
            new_pieces.append("(" + piece[:-2])
        elif piece == CLOSE:
            new_pieces.append(")")
        elif piece in PIECE_MAPPING:
            new_pieces.append(PIECE_MAPPING[piece])
        else:
            # maxsplit=1 because of words like 1990-EQU-100
            tag, word = piece.split("-", maxsplit=1)
            if word.find("'") >= 0 or word.find("(") >= 0 or word.find(")") >= 0:
                raise ValueError("Unhandled weird node: {}".format(piece))
            if word.endswith("_"):
                word = word[:-1] + "'"
            date_match = DATE_RE.match(word)
            if date_match:
                # 10_30 special case sent_07072
                # 16_30 special case sent_07098
                # 21_15 special case sent_07099 and others
                word = date_match.group(1) + ":" + date_match.group(2)
            integer_percent = INTEGER_PERCENT_RE.match(word)
            if integer_percent:
                word = integer_percent.group(1) + "_%%"
            range_percent = RANGE_PERCENT_RE.match(word)
            if range_percent:
                word = range_percent.group(1) + "," + range_percent.group(2) + "_%%"
            percent = DECIMAL_PERCENT_RE.match(word)
            if percent:
                word = percent.group(1) + "," + percent.group(2) + "_%%"
            decimal = DECIMAL_RE.match(word)
            if decimal:
                word = decimal.group(1) + "," + decimal.group(2)
            # there are words which are multiple words mashed together
            # with _ for some reason
            # also, words which end in ' are replaced with _
            # fortunately, no words seem to have both
            # splitting like this means the tags are likely wrong,
            # but the conparser needs to retag anyway, so it shouldn't matter
            word_pieces = word.split("_")
            for word_piece in word_pieces:
                new_pieces.append("(%s %s)" % (tag, word_piece))
    new_pieces.append(")")

    text = " ".join(new_pieces)
    trees = read_trees(text)
    if len(trees) > 1:
        raise ValueError("Unexpected number of trees!")
    return trees[0]

def extract_ngrams(sentence, process_func, ngram_len=4):
    """
    Return the list of ngram_len-grams (as tuples) of the sentence's words.

    process_func extracts the word list from the sentence object.
    A sentence shorter than ngram_len yields one tuple of all its words.
    """
    leaf_words = [x for x in process_func(sentence)]
    # normalize a tokenization difference: bare "l" becomes "l'"
    leaf_words = ["l'" if x == "l" else x for x in leaf_words]
    if len(leaf_words) <= ngram_len:
        return [tuple(leaf_words)]
    # its[k] is the word list shifted by k; zipping them yields the ngrams
    its = [leaf_words[i:i+len(leaf_words)-ngram_len+1] for i in range(ngram_len)]
    return [words for words in itertools.zip_longest(*its)]

def build_ngrams(sentences, process_func, id_func, ngram_len=4):
    """
    Turn the list of processed trees into a bunch of ngrams

    The returned map is from tuple to set of ids

    The idea being that this map can be used to search for trees to
    match datasets
    """
    ngram_map = defaultdict(set)
    for sentence in tqdm(sentences, postfix="Extracting ngrams"):
        sentence_id = id_func(sentence)
        ngrams = extract_ngrams(sentence, process_func, ngram_len)
        for ngram in ngrams:
            ngram_map[ngram].add(sentence_id)
    return ngram_map

# just the tokens (maybe use words?
# depends on MWT in the con dataset)
DEP_PROCESS_FUNC = lambda x: [t.text.lower() for t in x.tokens]
# find the comment with "sent_id" in it, take just the id itself
DEP_ID_FUNC = lambda x: [c for c in x.comments if c.startswith("# sent_id")][0].split()[-1]

# lowercased leaf words of a constituency tree
CON_PROCESS_FUNC = lambda x: [y.lower() for y in x.leaf_labels()]

def match_ngrams(sentence_ngrams, ngram_map, debug=False):
    """
    Check if there is a SINGLE matching sentence in the ngram_map for these ngrams

    If an ngram shows up in multiple sentences, that is okay, but we ignore that info
    If an ngram shows up in just one sentence, that is considered the match
    If a different ngram then shows up in a different sentence, that is a problem
    TODO: taking the intersection of all non-empty matches might be better

    Returns the matched sentence id, or None if there is no unambiguous match
    or more than half of the ngrams are unknown.
    """
    if debug:
        print("NGRAMS FOR DEBUG SENTENCE:")
    potential_match = None
    unknown_ngram = 0
    for ngram in sentence_ngrams:
        con_matches = ngram_map[ngram]
        if debug:
            print("{} matched {}".format(ngram, len(con_matches)))
        if len(con_matches) == 0:
            unknown_ngram += 1
            continue
        if len(con_matches) > 1:
            continue
        # get the one & only element from the set
        con_match = next(iter(con_matches))
        if debug:
            print("  {}".format(con_match))
        if potential_match is None:
            potential_match = con_match
        elif potential_match != con_match:
            # two different unique matches: ambiguous, so no match
            return None
    if unknown_ngram > len(sentence_ngrams) / 2:
        return None
    return potential_match

def match_sentences(con_tree_map, con_vit_ngrams, dep_sentences, split_name, debug_sentence=None):
    """
    Match ngrams in the dependency sentences to the constituency sentences

    Then, to make sure the constituency sentence wasn't split into two
    in the UD dataset, this checks the ngrams in the reverse direction

    Some examples of things which don't match:
    VIT-4769 Insegnanti non vedenti, insegnanti non autosufficienti con protesi agli arti inferiori.
    this is duplicated in the original dataset, so the matching algorithm can't possibly work

    VIT-4796 I posti istituiti con attività di sostegno dei docenti che ottengono il trasferimento su classi di concorso;
    the correct con match should be sent_04829 but the brackets on that tree are broken

    Returns a map from constituency id to dependency id.
    """
    con_to_dep_matches = {}
    dep_ngram_map = build_ngrams(dep_sentences, DEP_PROCESS_FUNC, DEP_ID_FUNC)
    unmatched = 0
    bad_match = 0
    for sentence in dep_sentences:
        sentence_ngrams = extract_ngrams(sentence, DEP_PROCESS_FUNC)
        potential_match = match_ngrams(sentence_ngrams, con_vit_ngrams, debug_sentence is not None and DEP_ID_FUNC(sentence) == debug_sentence)
        if potential_match is None:
            # only report the first few failures to keep the log readable
            if unmatched < 5:
                print("Could not match the following sentence: {} {}".format(DEP_ID_FUNC(sentence), sentence.text))
            unmatched += 1
            continue
        if potential_match not in con_tree_map:
            raise ValueError("wtf")
        # reverse check: the matched con tree must map back to a dep sentence
        con_ngrams = extract_ngrams(con_tree_map[potential_match], CON_PROCESS_FUNC)
        reverse_match = match_ngrams(con_ngrams, dep_ngram_map)
        if reverse_match is None:
            #print("Matched sentence {} to sentence {} but the reverse match failed".format(sentence.text, " ".join(con_tree_map[potential_match].leaf_labels())))
            bad_match += 1
            continue
        con_to_dep_matches[potential_match] = reverse_match
    print("Failed to match %d sentences and found %d spurious matches in the %s section" % (unmatched, bad_match, split_name))
    return con_to_dep_matches

# MWTs which should not be treated as ADP+DET expansions
EXCEPTIONS = ["gliene", "glielo", "gliela", "eccoci"]

def get_mwt(*dep_datasets):
    """
    Get the ADP/DET MWTs from the UD dataset

    This class of MWT are expanded in the UD but not the constituencies

    Returns a map from MWT surface text to its (ADP text, DET text) expansion.
    Raises ValueError on an MWT shape it does not know how to handle or on
    inconsistent expansions of the same surface text.
    """
    mwt_map = {}
    for dataset in dep_datasets:
        for sentence in dataset.sentences:
            for token in sentence.tokens:
                if len(token.words) == 1:
                    continue
                # words such as "accorgermene" we just skip over
                # those are already expanded in the constituency dataset
                # TODO: the clitics are actually expanded
                # weirdly, maybe need to compensate for that
                if token.words[0].upos in ('VERB', 'AUX') and all(word.upos == 'PRON' for word in token.words[1:]):
                    continue
                if token.text.lower() in EXCEPTIONS:
                    continue
                # only ADP+DET pairs are handled; anything else is an error
                if len(token.words) != 2 or token.words[0].upos != 'ADP' or token.words[1].upos != 'DET':
                    raise ValueError("Not sure how to handle this: {}".format(token))
                expansion = (token.words[0].text, token.words[1].text)
                if token.text in mwt_map:
                    if mwt_map[token.text] != expansion:
                        raise ValueError("Inconsistent MWT: {} -> {} or {}".format(token.text, expansion, mwt_map[token.text]))
                    continue
                #print("Expanding {} to {}".format(token.text, expansion))
                mwt_map[token.text] = expansion
    return mwt_map

def update_mwts_and_special_cases(original_tree, dep_sentence, mwt_map, tsurgeon_processor):
    """
    Replace MWT structures with their UD equivalents, along with some other minor tsurgeon based edits

    original_tree: the tree as read from VIT
    dep_sentence: the UD dependency dataset version of this sentence

    Returns (updated_tree, operations), where operations is the list of
    [tregex, tsurgeon...] edits that were applied (possibly empty).
    """
    updated_tree = original_tree

    operations = []

    # first, remove titles or testo from the start of a sentence
    con_words = updated_tree.leaf_labels()
    if con_words[0] == "Tit'":
        operations.append(["/^Tit'$/=prune !, __", "prune prune"])
    elif con_words[0] == "TESTO":
        operations.append(["/^TESTO$/=prune !, __", "prune prune"])
    elif con_words[0] == "testo":
        operations.append(["/^testo$/ !, __ . /^:$/=prune", "prune prune"])
        operations.append(["/^testo$/=prune !, __", "prune prune"])

    if len(con_words) >= 2 and con_words[-2] == '...' and con_words[-1] == '.':
        # the most recent VIT constituency has some sentence final . after a ...
        # the UD dataset has a more typical ... ending instead
        # these lines used to say "riempire" which was rather odd
        operations.append(["/^[.][.][.]$/ . /^[.]$/=prune", "prune prune"])

    # a few constituent tags are simply errors which need to be fixed
    if original_tree.children[0].label == 'p':
        # 'p' shouldn't be at root
        operations.append(["_ROOT_ < p=p", "relabel p cp"])
    # fix one specific tree if it has an s_top in it
    operations.append(["s_top=stop < (in=in < più=piu)", "replace piu (q più)", "relabel in sq", "relabel stop sa"])
    # sect doesn't exist as a constituent. replace it with sa
    operations.append(["sect=sect < num", "relabel sect sa"])
    # ppas as an internal node gets removed
    operations.append(["ppas=ppas < (__ < __)", "excise ppas ppas"])

    # now assemble a bunch of regex to split and otherwise manipulate
    # the MWT in the trees
    for token in dep_sentence.tokens:
        if len(token.words) == 1:
            continue
        if token.text in mwt_map:
            mwt_pieces = mwt_map[token.text]
            if len(mwt_pieces) != 2:
                raise NotImplementedError("Expected exactly 2 pieces of mwt for %s" % token.text)
            # the MWT words in the UD version will have ' when needed,
            # but the corresponding ' is skipped in the con version of VIT,
            # hence the replace("'", "")
            # however, all' has the ' included, because this is a
            # constituent treebank, not a consistent treebank
            search_regex = "/^(?i:%s(?:')?)$/" % token.text.replace("'", "")
            # tags which seem to be relevant:
            # avvl|ccom|php|part|partd|partda
            tregex = "__ !> __ <<<%d (%s=child > (__=parent $+ sn=sn))" % (token.id[0], search_regex)
            tsurgeons = ["insert (art %s) >0 sn" % mwt_pieces[1], "relabel child %s" % mwt_pieces[0]]
            operations.append([tregex] + tsurgeons)

            # fallback when there is no following sn sibling
            tregex = "__ !> __ <<<%d (%s=child > (__=parent !$+ sn !$+ (art < %s)))" % (token.id[0], search_regex, mwt_pieces[1])
            tsurgeons = ["insert (art %s) $- parent" % mwt_pieces[1], "relabel child %s" % mwt_pieces[0]]
            operations.append([tregex] + tsurgeons)
        elif len(token.words) == 2:
            #print("{} not in mwt_map".format(token.text))
            # apparently some trees like sent_00381 and sent_05070
            # have the clitic in a non-projective manner
            # [vcl-essersi, vppin-sparato, compt-[clitdat-si
            # intj-figurarsi, fs-[cosu-quando, f-[ibar-[clit-si
            # and before you ask, there are also clitics which are
            # simply not there at all, rather than always attached
            # in a non-projective manner
            tregex = "__=parent < (/^(?i:%s)$/=child . (__=np !< __ . (/^clit/=clit < %s)))" % (token.text, token.words[1].text)
            tsurgeon = "moveprune clit $- parent"
            operations.append([tregex, tsurgeon])

            # there are also some trees which don't have clitics
            # for example, trees should look like this:
            # [ibar-[vsup-poteva, vcl-rivelarsi], compc-[clit-si, sn-[...]]]
            # however, at least one such example for rivelarsi instead
            # looks like this, with no corresponding clit
            # [... vcl-rivelarsi], compc-[sn-[in-ancora]]
            # note that is the actual tag, not just me being pissed off
            # breaking down the tregex:
            # the child is the original MWT, not split
            # !. clit verifies that it is not split (and stops the tsurgeon once fixed)
            # !$+ checks that the parent of the MWT is the last element under parent
            # note that !. can leave the immediate parent to touch the clit
            # neighbor will be the place the new clit will be sticking out
            tregex = "__=parent < (/^(?i:%s)$/=child !. /^clit/) !$+ __ > (__=gp $+ __=neighbor)" % token.text
            tsurgeon = "insert (clit %s) >0 neighbor" % token.words[1].text
            operations.append([tregex, tsurgeon])

            # secondary option: while most trees are like the above,
            # with an outer bracket around the MWT and another verb,
            # some go straight into the next phrase
            # sent_05076
            # sv5-[vcl-adeguandosi, compin-[sp-[part-alle, ...
            tregex = "__=parent < (/^(?i:%s)$/=child !. /^clit/) $+ __" % token.text
            tsurgeon = "insert (clit %s) $- parent" % token.words[1].text
            operations.append([tregex, tsurgeon])
        else:
            # MWTs with 3+ words and not in mwt_map are left alone
            pass
    if len(operations) > 0:
        updated_tree = tsurgeon_processor.process(updated_tree, *operations)[0]
    return updated_tree, operations

def update_tree(original_tree, dep_sentence, con_id, dep_id, mwt_map, tsurgeon_processor):
    """
    Update a tree using the mwt_map and tsurgeon to expand some MWTs

    Then replace the words in the con tree with the words in the dep tree

    Raises ValueError (with a detailed dump of both trees and the applied
    tsurgeons) if the word counts no longer line up after the edits.
    """
    ud_words = [x.text for x in dep_sentence.words]

    updated_tree, operations = update_mwts_and_special_cases(original_tree, dep_sentence, mwt_map, tsurgeon_processor)

    # this checks number of words
    try:
        updated_tree = updated_tree.replace_words(ud_words)
    except ValueError as e:
        raise ValueError("Failed to process {} {}:\nORIGINAL TREE\n{}\nUPDATED TREE\n{}\nUPDATED LEAVES\n{}\nUD TEXT\n{}\nTsurgeons applied:\n{}\n".format(con_id, dep_id, original_tree, updated_tree, updated_tree.leaf_labels(), ud_words, "\n".join("{}".format(op) for op in operations))) from e
    return updated_tree

# train set:
# 858: missing close parens in the UD conversion
# 1169: 'che', 'poi', 'tutti', 'i', 'Paesi', 'ue', '.' -> 'per', 'tutti', 'i', 'paesi', 'Ue', '.'
# 2375: the problem is inconsistent treatment of s_p_a_
# 05052: the heuristic to fill in a missing "si" doesn't work because there's
# already another "si" immediately after
#
# test set:
# 09764: weird punct at end
# 10058: weird punct at end
IGNORE_IDS = ["sent_00867", "sent_01169", "sent_02375", "sent_05052", "sent_09764", "sent_10058"]

def extract_updated_dataset(con_tree_map, dep_sentence_map, split_ids, mwt_map, tsurgeon_processor):
    """
    Update constituency trees using the information in the dependency treebank

    Returns a list of ProcessedTree for every (con_id, dep_id) pair in
    split_ids, skipping the known-bad IGNORE_IDS.
    """
    trees = []
    for con_id, dep_id in tqdm(split_ids.items()):
        # skip a few trees which have non-MWT word modifications
        if con_id in IGNORE_IDS:
            continue
        original_tree = con_tree_map[con_id]
        dep_sentence = dep_sentence_map[dep_id]
        updated_tree = update_tree(original_tree, dep_sentence, con_id, dep_id, mwt_map, tsurgeon_processor)

        trees.append(ProcessedTree(con_id, dep_id, updated_tree))
    return trees

def read_updated_trees(paths, debug_sentence=None):
    """
    Read the raw VIT constituency file and the UD VIT treebank, align them,
    and return the updated (train, dev, test) tree lists.
    """
    # original version with more errors
    #con_filename = os.path.join(con_directory, "2011-12-20", "Archive", "VIT_newconstsynt.txt")
    # this is the April 2022 version
    #con_filename = os.path.join(con_directory, "VIT_newconstsynt.txt")
    # the most recent update from ELRA may look like this?
    # it's what we got, at least
    # con_filename = os.path.join(con_directory, "italian", "VITwritten", "VITconstsyntNumb")

    # needs at least UD 2.11 or this will not work
    con_directory = paths["CONSTITUENCY_BASE"]
    ud_directory = os.path.join(paths["UDBASE"], "UD_Italian-VIT")

    con_filename = os.path.join(con_directory, "italian", "it_vit", "VITwritten", "VITconstsyntNumb")
    ud_vit_train = os.path.join(ud_directory, "it_vit-ud-train.conllu")
    ud_vit_dev = os.path.join(ud_directory, "it_vit-ud-dev.conllu")
    ud_vit_test = os.path.join(ud_directory, "it_vit-ud-test.conllu")

    print("Reading UD train/dev/test from %s" % ud_directory)
    ud_train_data = CoNLL.conll2doc(input_file=ud_vit_train)
    ud_dev_data = CoNLL.conll2doc(input_file=ud_vit_dev)
    ud_test_data = CoNLL.conll2doc(input_file=ud_vit_test)

    ud_vit_train_map = { DEP_ID_FUNC(x) : x for x in ud_train_data.sentences }
    ud_vit_dev_map = { DEP_ID_FUNC(x) : x for x in ud_dev_data.sentences }
    ud_vit_test_map = { DEP_ID_FUNC(x) : x for x in ud_test_data.sentences }

    print("Getting ADP/DET expansions from UD data")
    mwt_map = get_mwt(ud_train_data, ud_dev_data, ud_test_data)

    con_sentences = read_constituency_file(con_filename)
    num_discarded = 0
    con_tree_map = {}
    for idx, sentence in enumerate(tqdm(con_sentences, postfix="Processing")):
        try:
            tree = raw_tree(sentence[1])
            # two header formats have been seen: "#ID=sent_xxx" and "ID#sent_xxx"
            if sentence[0].startswith("#ID="):
                tree_id = sentence[0].split("=")[-1]
            else:
                tree_id = sentence[0].split("#")[-1]
            # don't care about the raw text?
            con_tree_map[tree_id] = tree
        except UnclosedTreeError as e:
            num_discarded = num_discarded + 1
            print("Discarding {} because of reading error:\n  {}: {}\n  {}".format(sentence[0], type(e), e, sentence[1]))
        except ExtraCloseTreeError as e:
            num_discarded = num_discarded + 1
            print("Discarding {} because of reading error:\n  {}: {}\n  {}".format(sentence[0], type(e), e, sentence[1]))
        except ValueError as e:
            print("Discarding {} because of reading error:\n  {}: {}\n  {}".format(sentence[0], type(e), e, sentence[1]))
            num_discarded = num_discarded + 1
            #raise ValueError("Could not process line %d" % idx) from e

    print("Discarded %d trees. Have %d trees left" % (num_discarded, len(con_tree_map)))
    if num_discarded > 0:
        raise ValueError("Oops! We thought all of the VIT trees were properly bracketed now")
    con_vit_ngrams = build_ngrams(con_tree_map.items(), lambda x: CON_PROCESS_FUNC(x[1]), lambda x: x[0])

    # TODO: match more sentences. some are probably missing because of MWT
    train_ids = match_sentences(con_tree_map, con_vit_ngrams, ud_train_data.sentences, "train", debug_sentence)
    dev_ids = match_sentences(con_tree_map, con_vit_ngrams, ud_dev_data.sentences, "dev", debug_sentence)
    test_ids = match_sentences(con_tree_map, con_vit_ngrams, ud_test_data.sentences, "test", debug_sentence)
    print("Remaining total trees: %d" % (len(train_ids) + len(dev_ids) + len(test_ids)))
    print("  {} train {} dev {} test".format(len(train_ids), len(dev_ids), len(test_ids)))
    print("Updating trees with MWT and newer tokens from UD...")

    # the moveprune feature requires a new corenlp release after 4.4.0
    with tsurgeon.Tsurgeon(classpath="$CLASSPATH") as tsurgeon_processor:
        train_trees = extract_updated_dataset(con_tree_map, ud_vit_train_map, train_ids, mwt_map, tsurgeon_processor)
        dev_trees = extract_updated_dataset(con_tree_map, ud_vit_dev_map, dev_ids, mwt_map, tsurgeon_processor)
        test_trees = extract_updated_dataset(con_tree_map, ud_vit_test_map,
                                             test_ids, mwt_map, tsurgeon_processor)

    return train_trees, dev_trees, test_trees

def convert_it_vit(paths, dataset_name, debug_sentence=None):
    """
    Read the trees, then write them out to the expected output_directory
    """
    train_trees, dev_trees, test_trees = read_updated_trees(paths, debug_sentence)

    # keep just the trees, dropping the con/dep ids
    train_trees = [x.tree for x in train_trees]
    dev_trees = [x.tree for x in dev_trees]
    test_trees = [x.tree for x in test_trees]

    output_directory = paths["CONSTITUENCY_DATA_DIR"]
    write_dataset([train_trees, dev_trees, test_trees], output_directory, dataset_name)

def main():
    """Convert it_vit; an optional argv[1] names a sentence id to debug."""
    paths = default_paths.get_default_paths()
    dataset_name = "it_vit"

    debug_sentence = sys.argv[1] if len(sys.argv) > 1 else None

    convert_it_vit(paths, dataset_name, debug_sentence)

if __name__ == '__main__':
    main()
diff --git a/stanza/stanza/utils/datasets/constituency/convert_spmrl.py b/stanza/stanza/utils/datasets/constituency/convert_spmrl.py
new file mode 100644
index 0000000000000000000000000000000000000000..706061d8673aa3ca6739291fe45c9589b0b986ca
--- /dev/null
+++ b/stanza/stanza/utils/datasets/constituency/convert_spmrl.py
import os

from stanza.models.constituency.parse_tree import Tree
from stanza.models.constituency.tree_reader import read_treebank
from stanza.utils.default_paths import get_default_paths

SHARDS = ("train", "dev", "test")

def add_root(tree):
    """
    Wrap a raw SPMRL tree in a ROOT node, first adding a phrase level
    over bare preterminal roots.
    """
    # NOTE(review): the first branch is "if" while the following two are
    # if/elif - after wrapping NN in NP the label no longer starts with
    # NE/XY, so behavior matches an elif chain; confirm the mixed chain
    # is intentional
    if tree.label.startswith("NN"):
        tree = Tree("NP", tree)
    if tree.label.startswith("NE"):
        tree = Tree("PN", tree)
    elif tree.label.startswith("XY"):
        tree = Tree("VROOT", tree)
    return Tree("ROOT", tree)

def convert_spmrl(input_directory, output_directory, short_name):
    """
    Read each SPMRL shard, add ROOT nodes, and write one .mrg file per shard.
    """
    for shard in SHARDS:
        tree_filename = os.path.join(input_directory, shard, shard + ".German.gold.ptb")
        trees = read_treebank(tree_filename, tree_callback=add_root)
        output_filename = os.path.join(output_directory, "%s_%s.mrg" % (short_name, shard))
        with open(output_filename, "w",
encoding="utf-8") as fout: + for tree in trees: + fout.write(str(tree)) + fout.write("\n") + print("Wrote %d trees to %s" % (len(trees), output_filename)) + +if __name__ == '__main__': + paths = get_default_paths() + output_directory = paths["CONSTITUENCY_DATA_DIR"] + input_directory = "extern_data/constituency/spmrl/SPMRL_SHARED_2014/GERMAN_SPMRL/gold/ptb" + convert_spmrl(input_directory, output_directory, "de_spmrl") + + diff --git a/stanza/stanza/utils/datasets/constituency/convert_starlang.py b/stanza/stanza/utils/datasets/constituency/convert_starlang.py new file mode 100644 index 0000000000000000000000000000000000000000..248762b5ab592651545826e15d8e2ad20c33bbfe --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/convert_starlang.py @@ -0,0 +1,96 @@ + +import os +import re + +from tqdm import tqdm + +from stanza.models.constituency import parse_tree +from stanza.models.constituency import tree_reader + +TURKISH_RE = re.compile(r"[{]turkish=([^}]+)[}]") + +DISALLOWED_LABELS = ('DT', 'DET', 's', 'vp', 'AFVP', 'CONJ', 'INTJ', '-XXX-') + +def read_tree(text): + """ + Reads in a tree, then extracts specifically the word from the specific format used + + Also converts LCB/RCB as needed + """ + trees = tree_reader.read_trees(text) + if len(trees) > 1: + raise ValueError("Tree file had two trees!") + tree = trees[0] + labels = tree.leaf_labels() + new_labels = [] + for label in labels: + match = TURKISH_RE.search(label) + if match is None: + raise ValueError("Could not find word in |{}|".format(label)) + word = match.group(1) + word = word.replace("-LCB-", "{").replace("-RCB-", "}") + new_labels.append(word) + + tree = tree.replace_words(new_labels) + #tree = tree.remap_constituent_labels(LABEL_MAP) + con_labels = tree.get_unique_constituent_labels([tree]) + if any(label in DISALLOWED_LABELS for label in con_labels): + raise ValueError("found an unexpected phrasal node {}".format(label)) + return tree + +def read_files(filenames, conversion, log): + trees = 
[] + for filename in filenames: + with open(filename, encoding="utf-8") as fin: + text = fin.read() + try: + tree = conversion(text) + if tree is not None: + trees.append(tree) + except ValueError as e: + if log: + print("-----------------\nFound an error in {}: {} Original text: {}".format(filename, e, text)) + return trees + +def read_starlang(paths, conversion=read_tree, log=True): + """ + Read the starlang trees, converting them using the given method. + + read_tree or any other conversion turns one file at a time to a sentence. + log is whether or not to log a ValueError - the NER division has many missing labels + """ + if isinstance(paths, str): + paths = (paths,) + + train_files = [] + dev_files = [] + test_files = [] + + for path in paths: + tree_files = [os.path.join(path, x) for x in os.listdir(path)] + train_files.extend([x for x in tree_files if x.endswith(".train")]) + dev_files.extend([x for x in tree_files if x.endswith(".dev")]) + test_files.extend([x for x in tree_files if x.endswith(".test")]) + + print("Reading %d total files" % (len(train_files) + len(dev_files) + len(test_files))) + train_treebank = read_files(tqdm(train_files), conversion=conversion, log=log) + dev_treebank = read_files(tqdm(dev_files), conversion=conversion, log=log) + test_treebank = read_files(tqdm(test_files), conversion=conversion, log=log) + + return train_treebank, dev_treebank, test_treebank + +def main(conversion=read_tree, log=True): + paths = ["extern_data/constituency/turkish/TurkishAnnotatedTreeBank-15", + "extern_data/constituency/turkish/TurkishAnnotatedTreeBank2-15", + "extern_data/constituency/turkish/TurkishAnnotatedTreeBank2-20"] + train_treebank, dev_treebank, test_treebank = read_starlang(paths, conversion=conversion, log=log) + + print("Train: %d" % len(train_treebank)) + print("Dev: %d" % len(dev_treebank)) + print("Test: %d" % len(test_treebank)) + + print(train_treebank[0]) + return train_treebank, dev_treebank, test_treebank + +if __name__ == 
'__main__': + main() diff --git a/stanza/stanza/utils/datasets/constituency/extract_all_silver_dataset.py b/stanza/stanza/utils/datasets/constituency/extract_all_silver_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..153d9a59c1badfb9478861bf921b1e40a5e01ce8 --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/extract_all_silver_dataset.py @@ -0,0 +1,46 @@ +""" +After running build_silver_dataset.py, this extracts the trees of all match levels at once + +For example + +python stanza/utils/datasets/constituency/extract_all_silver_dataset.py --output_prefix /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_ --parsed_trees /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_wiki_a*trees + +cat /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_[012345678].mrg | sort | uniq | shuf > /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_sort.mrg + +shuf /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_sort.mrg | head -n 200000 > /u/nlp/data/constituency-parser/chinese/2024_zh_wiki/zh_silver_200K.mrg +""" + +import argparse +from collections import defaultdict +import json + +def parse_args(): + parser = argparse.ArgumentParser(description="After finding common trees using build_silver_dataset, this extracts them all or just the ones from a particular level of accuracy") + parser.add_argument('--parsed_trees', type=str, nargs='+', help='Input file(s) of trees parsed into the build_silver_dataset json format.') + parser.add_argument('--output_prefix', type=str, default=None, help='Prefix to use for outputting trees') + parser.add_argument('--output_suffix', type=str, default=".mrg", help='Suffix to use for outputting trees') + args = parser.parse_args() + + return args + +def main(): + args = parse_args() + + trees = defaultdict(list) + for filename in args.parsed_trees: + with open(filename, encoding='utf-8') as fin: + for line in fin.readlines(): + tree = json.loads(line) + 
trees[tree['count']].append(tree['tree']) + + for score, tree_list in trees.items(): + filename = "%s%s%s" % (args.output_prefix, score, args.output_suffix) + with open(filename, 'w', encoding='utf-8') as fout: + for tree in tree_list: + fout.write(tree) + fout.write('\n') + +if __name__ == '__main__': + main() + + diff --git a/stanza/stanza/utils/datasets/constituency/relabel_tags.py b/stanza/stanza/utils/datasets/constituency/relabel_tags.py new file mode 100644 index 0000000000000000000000000000000000000000..3fc207d6114d0e1716a41b8ba5924139bd52c4f7 --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/relabel_tags.py @@ -0,0 +1,48 @@ +""" +Retag an S-expression tree with a new set of POS tags + +Also includes an option to write the new trees as bracket_labels +(essentially, skipping the treebank_to_labeled_brackets step) +""" + +import argparse +import logging + +from stanza import Pipeline +from stanza.models.constituency import retagging +from stanza.models.constituency import tree_reader +from stanza.models.constituency.utils import retag_trees + +logger = logging.getLogger('stanza') + +def parse_args(): + parser = argparse.ArgumentParser(description="Script that retags a tree file") + parser.add_argument('--lang', default='vi', type=str, help='Language') + parser.add_argument('--input_file', default='data/constituency/vi_vlsp21_train.mrg', help='File to retag') + parser.add_argument('--output_file', default='vi_vlsp21_train_retagged.mrg', help='Where to write the retagged trees') + retagging.add_retag_args(parser) + + parser.add_argument('--bracket_labels', action='store_true', help='Write the trees as bracket labels instead of S-expressions') + + args = parser.parse_args() + args = vars(args) + retagging.postprocess_args(args) + + return args + +def main(): + args = parse_args() + + retag_pipeline = retagging.build_retag_pipeline(args) + + train_trees = tree_reader.read_treebank(args['input_file']) + logger.info("Retagging %d trees using %s", 
len(train_trees), args['retag_package']) + train_trees = retag_trees(train_trees, retag_pipeline, args['retag_xpos']) + tree_format = "{:L}" if args['bracket_labels'] else "{}" + with open(args['output_file'], "w") as fout: + for tree in train_trees: + fout.write(tree_format.format(tree)) + fout.write("\n") + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/utils/datasets/constituency/selftrain.py b/stanza/stanza/utils/datasets/constituency/selftrain.py new file mode 100644 index 0000000000000000000000000000000000000000..b4798fbd873298c8f1e0ae07096ac2dbd450b44e --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/selftrain.py @@ -0,0 +1,268 @@ +""" +Common methods for the various self-training data collection scripts +""" + +import logging +import os +import random +import re + +import stanza +from stanza.models.common import utils +from stanza.models.common.bert_embedding import TextTooLongError +from stanza.utils.get_tqdm import get_tqdm + +logger = logging.getLogger('stanza') +tqdm = get_tqdm() + +def common_args(parser): + parser.add_argument( + '--output_file', + default='data/constituency/vi_silver.mrg', + help='Where to write the silver trees' + ) + parser.add_argument( + '--lang', + default='vi', + help='Which language tools to use for tokenization and POS' + ) + parser.add_argument( + '--num_sentences', + type=int, + default=-1, + help='How many sentences to get per file (max)' + ) + parser.add_argument( + '--models', + default='saved_models/constituency/vi_vlsp21_inorder.pt', + help='What models to use for parsing. 
comma-separated' + ) + parser.add_argument( + '--package', + default='default', + help='Which package to load pretrain & charlm from for the parsers' + ) + parser.add_argument( + '--output_ptb', + default=False, + action='store_true', + help='Output trees in PTB brackets (default is a bracket language format)' + ) + +def add_length_args(parser): + parser.add_argument( + '--min_len', + default=5, + type=int, + help='Minimum length sentence to keep. None = unlimited' + ) + parser.add_argument( + '--no_min_len', + dest='min_len', + action='store_const', + const=None, + help='No minimum length' + ) + parser.add_argument( + '--max_len', + default=100, + type=int, + help='Maximum length sentence to keep. None = unlimited' + ) + parser.add_argument( + '--no_max_len', + dest='max_len', + action='store_const', + const=None, + help='No maximum length' + ) + +def build_ssplit_pipe(ssplit, lang): + if ssplit: + return stanza.Pipeline(lang, processors="tokenize") + else: + return stanza.Pipeline(lang, processors="tokenize", tokenize_no_ssplit=True) + +def build_tag_pipe(ssplit, lang, foundation_cache=None): + if ssplit: + return stanza.Pipeline(lang, processors="tokenize,pos", foundation_cache=foundation_cache) + else: + return stanza.Pipeline(lang, processors="tokenize,pos", tokenize_no_ssplit=True, foundation_cache=foundation_cache) + +def build_parser_pipes(lang, models, package="default", foundation_cache=None): + """ + Build separate pipelines for each parser model we want to use + + It is highly recommended to pass in a FoundationCache to reuse bottom layers + """ + parser_pipes = [] + for model_name in models.split(","): + if os.path.exists(model_name): + # if the model name exists as a file, treat it as the path to the model + pipe = stanza.Pipeline(lang, processors="constituency", package=package, constituency_model_path=model_name, constituency_pretagged=True, foundation_cache=foundation_cache) + else: + # otherwise, assume it is a package name? 
+ pipe = stanza.Pipeline(lang, processors={"constituency": model_name}, constituency_pretagged=True, package=None, foundation_cache=foundation_cache) + parser_pipes.append(pipe) + return parser_pipes + +def split_docs(docs, ssplit_pipe, max_len=140, max_word_len=50, chunk_size=2000): + """ + Using the ssplit pipeline, break up the documents into sentences + + Filters out sentences which are too long or have words too long. + + This step is necessary because some web text has unstructured + sentences which overwhelm the tagger, or even text with no + whitespace which breaks the charlm in the tokenizer or tagger + """ + raw_sentences = 0 + filtered_sentences = 0 + new_docs = [] + + logger.info("Splitting raw docs into sentences: %d", len(docs)) + for chunk_start in tqdm(range(0, len(docs), chunk_size)): + chunk = docs[chunk_start:chunk_start+chunk_size] + chunk = [stanza.Document([], text=t) for t in chunk] + chunk = ssplit_pipe(chunk) + sentences = [s for d in chunk for s in d.sentences] + raw_sentences += len(sentences) + sentences = [s for s in sentences if len(s.words) < max_len] + sentences = [s for s in sentences if max(len(w.text) for w in s.words) < max_word_len] + filtered_sentences += len(sentences) + new_docs.extend([s.text for s in sentences]) + + logger.info("Split sentences: %d", raw_sentences) + logger.info("Sentences filtered for length: %d", filtered_sentences) + return new_docs + +# from https://stackoverflow.com/questions/2718196/find-all-chinese-text-in-a-string-using-python-and-regex +ZH_RE = re.compile(u'[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]', re.UNICODE) +# https://stackoverflow.com/questions/6787716/regular-expression-for-japanese-characters +JA_RE = re.compile(u'[一-龠ぁ-ゔァ-ヴー々〆〤ヶ]', re.UNICODE) +DEV_RE = re.compile(u'[\u0900-\u097f]', re.UNICODE) + +def tokenize_docs(docs, pipe, min_len, max_len): + """ + Turn the text in docs into a list of whitespace separated sentences + + docs: a list of strings + pipe: a Stanza pipeline for tokenizing + 
min_len, max_len: can be None to not filter by this attribute + """ + results = [] + docs = [stanza.Document([], text=t) for t in docs] + if len(docs) == 0: + return results + pipe(docs) + is_zh = pipe.lang and pipe.lang.startswith("zh") + is_ja = pipe.lang and pipe.lang.startswith("ja") + is_vi = pipe.lang and pipe.lang.startswith("vi") + for doc in docs: + for sentence in doc.sentences: + if min_len and len(sentence.words) < min_len: + continue + if max_len and len(sentence.words) > max_len: + continue + text = sentence.text + if (text.find("|") >= 0 or text.find("_") >= 0 or + text.find("<") >= 0 or text.find(">") >= 0 or + text.find("[") >= 0 or text.find("]") >= 0 or + text.find('—') >= 0): # an em dash, seems to be part of lists + continue + # the VI tokenizer in particular doesn't split these well + if any(any(w.text.find(c) >= 0 and len(w.text) > 1 for w in sentence.words) + for c in '"()'): + continue + text = [w.text.replace(" ", "_") for w in sentence.words] + text = " ".join(text) + if any(len(w.text) >= 50 for w in sentence.words): + # skip sentences where some of the words are unreasonably long + # could make this an argument + continue + if not is_zh and len(ZH_RE.findall(text)) > 250: + # some Chinese sentences show up in VI Wikipedia + # we want to eliminate ones which will choke the bert models + continue + if not is_ja and len(JA_RE.findall(text)) > 150: + # some Japanese sentences also show up in VI Wikipedia + # we want to eliminate ones which will choke the bert models + continue + if is_vi and len(DEV_RE.findall(text)) > 100: + # would need some list of languages that use + # Devanagari to eliminate sentences from all datasets. 
+ # Otherwise we might accidentally throw away all the + # text from a language we need (although that would be obvious) + continue + results.append(text) + return results + +def find_matching_trees(docs, num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=True, chunk_size=10, max_len=140, min_len=10, output_ptb=False): + """ + Find trees where all the parsers in parser_pipes agree + + docs should be a list of strings. + one sentence per string or a whole block of text as long as the tag_pipe can break it into sentences + + num_sentences > 0 gives an upper limit on how many sentences to extract. + If < 0, all possible sentences are extracted + + accepted_trees is a running tally of all the trees already built, + so that we don't reuse the same sentence if we see it again + """ + if num_sentences < 0: + tqdm_total = len(docs) + else: + tqdm_total = num_sentences + + if output_ptb: + output_format = "{}" + else: + output_format = "{:L}" + + with tqdm(total=tqdm_total, leave=False) as pbar: + if shuffle: + random.shuffle(docs) + new_trees = set() + for chunk_start in range(0, len(docs), chunk_size): + chunk = docs[chunk_start:chunk_start+chunk_size] + chunk = [stanza.Document([], text=t) for t in chunk] + + if num_sentences < 0: + pbar.update(len(chunk)) + + # first, retag the sentences + tag_pipe(chunk) + + chunk = [d for d in chunk if len(d.sentences) > 0] + if max_len is not None: + # for now, we don't have a good way to deal with sentences longer than the bert maxlen + chunk = [d for d in chunk if max(len(s.words) for s in d.sentences) < max_len] + if len(chunk) == 0: + continue + + parses = [] + try: + for pipe in parser_pipes: + pipe(chunk) + trees = [output_format.format(sent.constituency) for doc in chunk for sent in doc.sentences if len(sent.words) >= min_len] + parses.append(trees) + except TextTooLongError as e: + # easiest is to skip this chunk - could theoretically save the other sentences + continue + + for tree in zip(*parses): + if 
len(set(tree)) != 1: + continue + tree = tree[0] + if tree in accepted_trees: + continue + if tree not in new_trees: + new_trees.add(tree) + if num_sentences >= 0: + pbar.update(1) + if num_sentences >= 0 and len(new_trees) >= num_sentences: + return new_trees + + return new_trees + diff --git a/stanza/stanza/utils/datasets/constituency/selftrain_it.py b/stanza/stanza/utils/datasets/constituency/selftrain_it.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6a608f8102d44690f1e3d0a767ff054e97ffcd --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/selftrain_it.py @@ -0,0 +1,120 @@ +"""Builds a self-training dataset from an Italian data source and two models + +The idea is that the top down and the inorder parsers should make +somewhat different errors, so hopefully the sum of an 86 f1 parser and +an 85.5 f1 parser will produce some half-decent silver trees which can +be used as self-training so that a new model can do better than either. + +One dataset used is PaCCSS, which has 63000 pairs of sentences: + +http://www.italianlp.it/resources/paccss-it-parallel-corpus-of-complex-simple-sentences-for-italian/ + +PaCCSS-IT: A Parallel Corpus of Complex-Simple Sentences for Automatic Text Simplification + Brunato, Dominique et al, 2016 + https://aclanthology.org/D16-1034 + +Even larger is the IT section of Europarl, which has 1900000 lines + +https://www.statmt.org/europarl/ + +Europarl: A Parallel Corpus for Statistical Machine Translation + Philipp Koehn + https://homepages.inf.ed.ac.uk/pkoehn/publications/europarl-mtsummit05.pdf +""" + +import argparse +import logging +import os +import random + +import stanza +from stanza.models.common.foundation_cache import FoundationCache +from stanza.utils.datasets.constituency import selftrain +from stanza.utils.get_tqdm import get_tqdm + +tqdm = get_tqdm() +logger = logging.getLogger('stanza') + +def parse_args(): + parser = argparse.ArgumentParser( + description="Script that converts a public 
IT dataset to silver standard trees" + ) + selftrain.common_args(parser) + parser.add_argument( + '--input_dir', + default='extern_data/italian', + help='Path to the PaCCSS corpus and europarl corpus' + ) + + parser.add_argument( + '--no_europarl', + default=True, + action='store_false', + dest='europarl', + help='Use the europarl dataset. Turning this off makes the script a lot faster' + ) + + parser.set_defaults(lang="it") + parser.set_defaults(package="vit") + parser.set_defaults(models="saved_models/constituency/it_best/it_vit_inorder_best.pt,saved_models/constituency/it_best/it_vit_topdown.pt") + parser.set_defaults(output_file="data/constituency/it_silver.mrg") + + args = parser.parse_args() + return args + +def get_paccss(input_dir): + """ + Read the paccss dataset, which is two sentences per line + """ + input_file = os.path.join(input_dir, "PaCCSS/data-set/PACCSS-IT.txt") + with open(input_file) as fin: + # the first line is a header line + lines = fin.readlines()[1:] + lines = [x.strip() for x in lines] + lines = [x.split("\t")[:2] for x in lines if x] + text = [y for x in lines for y in x] + logger.info("Read %d sentences from %s", len(text), input_file) + return text + +def get_europarl(input_dir, ssplit_pipe): + """ + Read the Europarl dataset + + This dataset needs to be tokenized and split into lines + """ + input_file = os.path.join(input_dir, "europarl/europarl-v7.it-en.it") + with open(input_file) as fin: + # the first line is a header line + lines = fin.readlines()[1:] + lines = [x.strip() for x in lines] + lines = [x for x in lines if x] + logger.info("Read %d docs from %s", len(lines), input_file) + lines = selftrain.split_docs(lines, ssplit_pipe) + return lines + +def main(): + """ + Combine the two datasets, parse them, and write out the results + """ + args = parse_args() + + foundation_cache = FoundationCache() + ssplit_pipe = selftrain.build_ssplit_pipe(ssplit=True, lang=args.lang) + tag_pipe = selftrain.build_tag_pipe(ssplit=False, 
lang=args.lang, foundation_cache=foundation_cache) + parser_pipes = selftrain.build_parser_pipes(args.lang, args.models, package=args.package, foundation_cache=foundation_cache) + + docs = get_paccss(args.input_dir) + if args.europarl: + docs.extend(get_europarl(args.input_dir, ssplit_pipe)) + + logger.info("Processing %d docs", len(docs)) + new_trees = selftrain.find_matching_trees(docs, args.num_sentences, set(), tag_pipe, parser_pipes, shuffle=False, chunk_size=100, output_ptb=args.output_ptb) + logger.info("Found %d unique trees which are the same between models" % len(new_trees)) + with open(args.output_file, "w") as fout: + for tree in sorted(new_trees): + fout.write(tree) + fout.write("\n") + + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/utils/datasets/constituency/selftrain_single_file.py b/stanza/stanza/utils/datasets/constituency/selftrain_single_file.py new file mode 100644 index 0000000000000000000000000000000000000000..2782655f7759654b5c0eff7b6f52ae28b89446ce --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/selftrain_single_file.py @@ -0,0 +1,88 @@ +""" +Builds a self-training dataset from a single file. + +Default is to assume one document of text per line. If a line has +multiple sentences, they will be split using the stanza tokenizer. 
+""" + +import argparse +import io +import logging +import os + +import numpy as np + +import stanza +from stanza.utils.datasets.constituency import selftrain +from stanza.utils.get_tqdm import get_tqdm + +logger = logging.getLogger('stanza') +tqdm = get_tqdm() + +def parse_args(): + """ + Only specific argument for this script is the file to process + """ + parser = argparse.ArgumentParser( + description="Script that converts a single file of text to silver standard trees" + ) + selftrain.common_args(parser) + parser.add_argument( + '--input_file', + default="vi_part_1.aa", + help='Path to the file to read' + ) + + args = parser.parse_args() + return args + + +def read_file(input_file): + """ + Read lines from an input file + + Takes care to avoid encoding errors at the end of Oscar files. + The Oscar splits sometimes break a utf-8 character in half. + """ + with open(input_file, "rb") as fin: + text = fin.read() + text = text.decode("utf-8", errors="replace") + with io.StringIO(text) as fin: + lines = fin.readlines() + return lines + + +def main(): + args = parse_args() + + # TODO: make ssplit an argument + ssplit_pipe = selftrain.build_ssplit_pipe(ssplit=True, lang=args.lang) + tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang) + parser_pipes = selftrain.build_parser_pipes(args.lang, args.models) + + # create a blank file. 
we will append to this file so that partial results can be used + with open(args.output_file, "w") as fout: + pass + + docs = read_file(args.input_file) + logger.info("Read %d lines from %s", len(docs), args.input_file) + docs = selftrain.split_docs(docs, ssplit_pipe) + + # breaking into chunks lets us output partial results and see the + # progress in log files + accepted_trees = set() + if len(docs) > 10000: + chunks = tqdm(np.array_split(docs, 100), disable=False) + else: + chunks = [docs] + for chunk in chunks: + new_trees = selftrain.find_matching_trees(chunk, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=False, chunk_size=100) + accepted_trees.update(new_trees) + + with open(args.output_file, "a") as fout: + for tree in sorted(new_trees): + fout.write(tree) + fout.write("\n") + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/utils/datasets/constituency/selftrain_vi_quad.py b/stanza/stanza/utils/datasets/constituency/selftrain_vi_quad.py new file mode 100644 index 0000000000000000000000000000000000000000..3426dcb453f6d05277b42b2d6bed2b523ab7d4bd --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/selftrain_vi_quad.py @@ -0,0 +1,98 @@ +""" +Processes the train section of VI QuAD into trees suitable for use in the conparser lm +""" + +import argparse +import json +import logging + +import stanza +from stanza.utils.datasets.constituency import selftrain + +logger = logging.getLogger('stanza') + +def parse_args(): + parser = argparse.ArgumentParser( + description="Script that converts vi quad to silver standard trees" + ) + selftrain.common_args(parser) + selftrain.add_length_args(parser) + parser.add_argument( + '--input_file', + default="extern_data/vietnamese/ViQuAD/train_ViQuAD.json", + help='Path to the ViQuAD train file' + ) + parser.add_argument( + '--tokenize_only', + default=False, + action='store_true', + help='Tokenize instead of writing trees' + ) + + args = parser.parse_args() + return args + +def 
parse_quad(text): + """ + Read in a file from the VI quad dataset + + The train file has a specific format: + the doc has a 'data' section + each block in the data is a separate document (138 in the train file, for example) + each block has a 'paragraphs' section + each paragrah has 'qas' and 'context'. we care about the qas + each piece of qas has 'question', which is what we actually want + """ + doc = json.loads(text) + + questions = [] + + for block in doc['data']: + paragraphs = block['paragraphs'] + for paragraph in paragraphs: + qas = paragraph['qas'] + for question in qas: + questions.append(question['question']) + + return questions + + +def read_quad(train_file): + with open(train_file) as fin: + text = fin.read() + + return parse_quad(text) + +def main(): + """ + Turn the train section of VI quad into a list of trees + """ + args = parse_args() + + docs = read_quad(args.input_file) + logger.info("Read %d lines from %s", len(docs), args.input_file) + if args.tokenize_only: + pipe = stanza.Pipeline(args.lang, processors="tokenize") + text = selftrain.tokenize_docs(docs, pipe, args.min_len, args.max_len) + with open(args.output_file, "w", encoding="utf-8") as fout: + for line in text: + fout.write(line) + fout.write("\n") + else: + tag_pipe = selftrain.build_tag_pipe(ssplit=False, lang=args.lang) + parser_pipes = selftrain.build_parser_pipes(args.lang, args.models) + + # create a blank file. 
we will append to this file so that partial results can be used + with open(args.output_file, "w") as fout: + pass + + accepted_trees = set() + new_trees = selftrain.find_matching_trees(docs, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=False, chunk_size=100) + new_trees = [tree for tree in new_trees if tree.find("(_SQ") >= 0] + with open(args.output_file, "a") as fout: + for tree in sorted(new_trees): + fout.write(tree) + fout.write("\n") + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/utils/datasets/constituency/selftrain_wiki.py b/stanza/stanza/utils/datasets/constituency/selftrain_wiki.py new file mode 100644 index 0000000000000000000000000000000000000000..1a2b5a3612e4ea37f3846b8b6e41ddc4d14c79c3 --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/selftrain_wiki.py @@ -0,0 +1,140 @@ +"""Builds a self-training dataset from an Italian data source and two models + +The idea is that the top down and the inorder parsers should make +somewhat different errors, so hopefully the sum of an 86 f1 parser and +an 85.5 f1 parser will produce some half-decent silver trees which can +be used as self-training so that a new model can do better than either. 
+ +The dataset used is PaCCSS, which has 63000 pairs of sentences: + +http://www.italianlp.it/resources/paccss-it-parallel-corpus-of-complex-simple-sentences-for-italian/ +""" + +import argparse +from collections import deque +import glob +import os +import random + +from stanza.models.common.foundation_cache import FoundationCache +from stanza.utils.datasets.constituency import selftrain +from stanza.utils.get_tqdm import get_tqdm + +tqdm = get_tqdm() + +def parse_args(): + parser = argparse.ArgumentParser( + description="Script that converts part of a wikipedia dump to silver standard trees" + ) + selftrain.common_args(parser) + parser.add_argument( + '--input_dir', + default='extern_data/vietnamese/wikipedia/text', + help='Path to the wikipedia dump after processing by wikiextractor' + ) + parser.add_argument( + '--no_shuffle', + dest='shuffle', + action='store_false', + help="Don't shuffle files when processing the directory" + ) + + parser.set_defaults(num_sentences=10000) + + args = parser.parse_args() + return args + +def list_wikipedia_files(input_dir): + """ + Get a list of wiki files under the input_dir + + Recursively traverse the directory, then sort + """ + if not os.path.isdir(input_dir) and os.path.split(input_dir)[1].startswith("wiki_"): + return [input_dir] + + wiki_files = [] + + recursive_files = deque() + recursive_files.extend(glob.glob(os.path.join(input_dir, "*"))) + while len(recursive_files) > 0: + next_file = recursive_files.pop() + if os.path.isdir(next_file): + recursive_files.extend(glob.glob(os.path.join(next_file, "*"))) + elif os.path.split(next_file)[1].startswith("wiki_"): + wiki_files.append(next_file) + + wiki_files.sort() + return wiki_files + +def read_wiki_file(filename): + """ + Read the text from a wiki file as a list of paragraphs. + + Each is its own item in the list. + Lines are separated by \n\n to give hints to the stanza tokenizer. + The first line after is skipped as it is usually the document title. 
+ """ + with open(filename) as fin: + lines = fin.readlines() + docs = [] + current_doc = [] + line_iterator = iter(lines) + line = next(line_iterator, None) + while line is not None: + if line.startswith(" 2: + # a lot of very short documents are links to related documents + # a single wikipedia can have tens of thousands of useless almost-duplicates + docs.append("\n\n".join(current_doc)) + current_doc = [] + else: + # not the start or end of a doc + # hopefully this is valid text + line = line.replace("()", " ") + line = line.replace("( )", " ") + line = line.strip() + if line.find("<") >= 0 or line.find(">") >= 0: + line = "" + if line: + current_doc.append(line) + line = next(line_iterator, None) + + if current_doc: + docs.append("\n\n".join(current_doc)) + return docs + +def main(): + args = parse_args() + + random.seed(1234) + + wiki_files = list_wikipedia_files(args.input_dir) + if args.shuffle: + random.shuffle(wiki_files) + + foundation_cache = FoundationCache() + tag_pipe = selftrain.build_tag_pipe(ssplit=True, lang=args.lang, foundation_cache=foundation_cache) + parser_pipes = selftrain.build_parser_pipes(args.lang, args.models, foundation_cache=foundation_cache) + + # create a blank file. 
we will append to this file so that partial results can be used + with open(args.output_file, "w") as fout: + pass + + accepted_trees = set() + for filename in tqdm(wiki_files, disable=False): + docs = read_wiki_file(filename) + new_trees = selftrain.find_matching_trees(docs, args.num_sentences, accepted_trees, tag_pipe, parser_pipes, shuffle=args.shuffle) + accepted_trees.update(new_trees) + + with open(args.output_file, "a") as fout: + for tree in sorted(new_trees): + fout.write(tree) + fout.write("\n") + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/utils/datasets/constituency/split_holdout.py b/stanza/stanza/utils/datasets/constituency/split_holdout.py new file mode 100644 index 0000000000000000000000000000000000000000..1bbfc7992b6d7caa3e6d9b5aecee046d7a65191c --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/split_holdout.py @@ -0,0 +1,64 @@ +""" +Split a constituency dataset randomly into 90/10 splits + +TODO: add a function to rotate the pieces of the split so that each +training instance gets seen once +""" + +import argparse +import os +import random + +from stanza.models.constituency import tree_reader +from stanza.utils.datasets.constituency.utils import copy_dev_test +from stanza.utils.default_paths import get_default_paths + +def write_trees(base_path, dataset_name, trees): + output_path = os.path.join(base_path, "%s_train.mrg" % dataset_name) + with open(output_path, "w", encoding="utf-8") as fout: + for tree in trees: + fout.write("%s\n" % tree) + + +def main(): + parser = argparse.ArgumentParser(description="Split a standard dataset into 90/10 proportions of train so there is held out training data") + parser.add_argument('--dataset', type=str, default="id_icon", help='dataset to split') + parser.add_argument('--base_dataset', type=str, default=None, help='output name for base dataset') + parser.add_argument('--holdout_dataset', type=str, default=None, help='output name for holdout dataset') + 
parser.add_argument('--ratio', type=float, default=0.1, help='Number of trees to hold out') + parser.add_argument('--seed', type=int, default=1234, help='Random seed') + args = parser.parse_args() + + if args.base_dataset is None: + args.base_dataset = args.dataset + "-base" + print("--base_dataset not set, using %s" % args.base_dataset) + + if args.holdout_dataset is None: + args.holdout_dataset = args.dataset + "-holdout" + print("--holdout_dataset not set, using %s" % args.holdout_dataset) + + base_path = get_default_paths()["CONSTITUENCY_DATA_DIR"] + copy_dev_test(base_path, args.dataset, args.base_dataset) + copy_dev_test(base_path, args.dataset, args.holdout_dataset) + + train_file = os.path.join(base_path, "%s_train.mrg" % args.dataset) + print("Reading %s" % train_file) + trees = tree_reader.read_tree_file(train_file) + + base_train = [] + holdout_train = [] + + random.seed(args.seed) + + for tree in trees: + if random.random() < args.ratio: + holdout_train.append(tree) + else: + base_train.append(tree) + + write_trees(base_path, args.base_dataset, base_train) + write_trees(base_path, args.holdout_dataset, holdout_train) + +if __name__ == '__main__': + main() + diff --git a/stanza/stanza/utils/datasets/constituency/split_weighted_ensemble.py b/stanza/stanza/utils/datasets/constituency/split_weighted_ensemble.py new file mode 100644 index 0000000000000000000000000000000000000000..af77d2a2ea218fe5c8efdf44c36384436f321bcf --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/split_weighted_ensemble.py @@ -0,0 +1,73 @@ +""" +Read in a dataset and split the train portion into pieces + +One chunk of the train will be the original dataset. + +Others will be a sampling from the original dataset of the same size, +but sampled with replacement, with the goal being to get a random +distribution of trees with some reweighting of the original trees. 
+""" + +import argparse +import os +import random + +from stanza.models.constituency import tree_reader +from stanza.models.constituency.parse_tree import Tree +from stanza.utils.datasets.constituency.utils import copy_dev_test +from stanza.utils.default_paths import get_default_paths + +def main(): + parser = argparse.ArgumentParser(description="Split a standard dataset into 1 base section and N-1 random redraws of training data") + parser.add_argument('--dataset', type=str, default="id_icon", help='dataset to split') + parser.add_argument('--seed', type=int, default=1234, help='Random seed') + parser.add_argument('--num_splits', type=int, default=5, help='Number of splits') + args = parser.parse_args() + + random.seed(args.seed) + + base_path = get_default_paths()["CONSTITUENCY_DATA_DIR"] + train_file = os.path.join(base_path, "%s_train.mrg" % args.dataset) + print("Reading %s" % train_file) + train_trees = tree_reader.read_tree_file(train_file) + + # For datasets with low numbers of certain constituents in the train set, + # we could easily find ourselves in a situation where all of the trees + # with a specific constituent have been randomly shuffled away from + # a random shuffle + # An example of this is there are 3 total trees with SQ in id_icon + # Therefore, we have to take a little care to guarantee at least one tree + # for each constituent type is in a random slice + # TODO: this doesn't compensate for transition schemes with compound transitions, + # such as in_order_compound. 
could do a similar boosting with one per transition type + constituents = sorted(Tree.get_unique_constituent_labels(train_trees)) + con_to_trees = {con: list() for con in constituents} + for tree in train_trees: + tree_cons = Tree.get_unique_constituent_labels(tree) + for con in tree_cons: + con_to_trees[con].append(tree) + for con in constituents: + print("%d trees with %s" % (len(con_to_trees[con]), con)) + + for i in range(args.num_splits): + dataset_name = "%s-random-%d" % (args.dataset, i) + + copy_dev_test(base_path, args.dataset, dataset_name) + if i == 0: + train_dataset = train_trees + else: + train_dataset = [] + for con in constituents: + train_dataset.extend(random.choices(con_to_trees[con], k=2)) + needed_trees = len(train_trees) - len(train_dataset) + if needed_trees > 0: + print("%d trees already chosen. Adding %d more" % (len(train_dataset), needed_trees)) + train_dataset.extend(random.choices(train_trees, k=needed_trees)) + output_filename = os.path.join(base_path, "%s_train.mrg" % dataset_name) + print("Writing {} trees to {}".format(len(train_dataset), output_filename)) + Tree.write_treebank(train_dataset, output_filename) + + +if __name__ == '__main__': + main() + diff --git a/stanza/stanza/utils/datasets/constituency/tokenize_wiki.py b/stanza/stanza/utils/datasets/constituency/tokenize_wiki.py new file mode 100644 index 0000000000000000000000000000000000000000..afbc8e4e2c0b4a216b21145428a04ea7407de602 --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/tokenize_wiki.py @@ -0,0 +1,104 @@ +""" +A short script to use a Stanza tokenizer to extract tokenized sentences from Wikipedia + +The first step is to convert a Wikipedia dataset using Prof. Attardi's wikiextractor: +https://github.com/attardi/wikiextractor + +This script then writes out sentences, one per line, whitespace separated +Some common issues with the tokenizer are accounted for by discarding those lines. 
+ +Also, to account for languages such as VI where whitespace occurs within words, +spaces are replaced with _ This should not cause any confusion, as any line with +a natural _ in has already been discarded. + +for i in `echo A B C D E F G H I J K`; do nlprun "python3 stanza/utils/datasets/constituency/tokenize_wiki.py --output_file /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_B$i.txt --lang it --max_len 120 --input_dir /u/nlp/data/Wikipedia/itwiki/B$i --tokenizer_model saved_models/tokenize/it_combined_tokenizer.pt --download_method None" -o /u/nlp/data/constituency-parser/italian/2024_wiki_tokenization/it_wiki_tokenized_B$i.out; done +""" + +import argparse +import logging + +import stanza +from stanza.models.common.bert_embedding import load_tokenizer, filter_data +from stanza.utils.datasets.constituency import selftrain_wiki +from stanza.utils.datasets.constituency.selftrain import add_length_args, tokenize_docs +from stanza.utils.get_tqdm import get_tqdm + +tqdm = get_tqdm() + +def parse_args(): + parser = argparse.ArgumentParser( + description="Script that converts part of a wikipedia dump to silver standard trees" + ) + parser.add_argument( + '--output_file', + default='vi_wiki_tokenized.txt', + help='Where to write the tokenized lines' + ) + parser.add_argument( + '--lang', + default='vi', + help='Which language tools to use for tokenization and POS' + ) + + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument( + '--input_dir', + default=None, + help='Path to the wikipedia dump after processing by wikiextractor' + ) + input_group.add_argument( + '--input_file', + default=None, + help='Path to a single file of the wikipedia dump after processing by wikiextractor' + ) + parser.add_argument( + '--bert_tokenizer', + default=None, + help='Which bert tokenizer (if any) to use to filter long sentences' + ) + parser.add_argument( + '--tokenizer_model', + default=None, + help='Use this model 
instead of the current Stanza tokenizer for this language' + ) + parser.add_argument( + '--download_method', + default=None, + help='Download pipeline models using this method (defaults to downloading updates from HF)' + ) + add_length_args(parser) + args = parser.parse_args() + return args + +def main(): + args = parse_args() + if args.input_dir is not None: + files = selftrain_wiki.list_wikipedia_files(args.input_dir) + elif args.input_file is not None: + files = [args.input_file] + else: + raise ValueError("Need to specify at least one file or directory!") + + if args.bert_tokenizer: + tokenizer = load_tokenizer(args.bert_tokenizer) + print("Max model length: %d" % tokenizer.model_max_length) + pipeline_args = {} + if args.tokenizer_model: + pipeline_args["tokenize_model_path"] = args.tokenizer_model + if args.download_method: + pipeline_args["download_method"] = args.download_method + pipe = stanza.Pipeline(args.lang, processors="tokenize", **pipeline_args) + + with open(args.output_file, "w", encoding="utf-8") as fout: + for filename in tqdm(files): + docs = selftrain_wiki.read_wiki_file(filename) + text = tokenize_docs(docs, pipe, args.min_len, args.max_len) + if args.bert_tokenizer: + filtered = filter_data(args.bert_tokenizer, [x.split() for x in text], tokenizer, logging.DEBUG) + text = [" ".join(x) for x in filtered] + for line in text: + fout.write(line) + fout.write("\n") + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/utils/datasets/constituency/treebank_to_labeled_brackets.py b/stanza/stanza/utils/datasets/constituency/treebank_to_labeled_brackets.py new file mode 100644 index 0000000000000000000000000000000000000000..af15933d7196121097dc38d78819fac58830709a --- /dev/null +++ b/stanza/stanza/utils/datasets/constituency/treebank_to_labeled_brackets.py @@ -0,0 +1,55 @@ +""" +Converts a PTB file to a format where all the brackets have labels on the start and end bracket. 
def main():
    """
    Convert a PTB bracket treebank into one labeled-bracket line per tree.

    Every tree is rendered with the tree's {:L} format code (labels on both
    the opening and closing brackets), optionally with a separator character
    standing in for spaces, and written to the output file so the result can
    be used for LM training.
    """
    parser = argparse.ArgumentParser(
        description="Script that converts a PTB treebank into a labeled bracketed file suitable for LM training"
    )
    parser.add_argument('ptb_file', help='Where to get the original PTB format treebank')
    parser.add_argument('label_file', help='Where to write the labeled bracketed file')
    parser.add_argument('--separator', default="_", help='What separator to use in place of spaces')
    parser.add_argument('--no_separator', dest='separator', action='store_const', const=None,
                        help="Don't use a separator")
    args = parser.parse_args()

    treebank = tree_reader.read_treebank(args.ptb_file)
    logger.info("Writing %d trees to %s", len(treebank), args.label_file)

    if args.separator:
        line_template = "{:%sL}\n" % args.separator
    else:
        line_template = "{:L}\n"

    with open(args.label_file, "w", encoding="utf-8") as outfile:
        for tree in tqdm(treebank):
            outfile.write(line_template.format(tree))
def write_dataset(datasets, output_dir, dataset_name):
    """Write the (train, dev, test) tree lists as <dataset_name>_<shard>.mrg files."""
    for trees, shard in zip(datasets, SHARDS):
        output_filename = os.path.join(output_dir, "%s_%s.mrg" % (dataset_name, shard))
        print("Writing {} trees to {}".format(len(trees), output_filename))
        parse_tree.Tree.write_treebank(trees, output_filename)

def split_treebank(treebank, train_size, dev_size):
    """
    Split a treebank deterministically

    The first train_size fraction becomes train, the next dev_size
    fraction becomes dev, and whatever remains becomes test.
    """
    total = len(treebank)
    train_cutoff = int(total * train_size)
    dev_cutoff = int(total * (train_size + dev_size))
    train = treebank[:train_cutoff]
    dev = treebank[train_cutoff:dev_cutoff]
    test = treebank[dev_cutoff:]
    return train, dev, test
REMAPPING = {
    '(ADV-MDP': '(RP-MDP',
    '(MPD': '(MDP',
    '(MP ': '(NP ',
    '(MP(': '(NP(',
    '(Np(': '(NP(',
    '(Np (': '(NP (',
    '(NLOC': '(NP-LOC',
    '(N-P-LOC': '(NP-LOC',
    '(N-p-loc': '(NP-LOC',
    '(NPDOB': '(NP-DOB',
    '(NPSUB': '(NP-SUB',
    '(NPTMP': '(NP-TMP',
    '(PPLOC': '(PP-LOC',
    '(SBA ': '(SBAR ',
    '(SBA-': '(SBAR-',
    '(SBA(': '(SBAR(',
    '(SBAS': '(SBAR',
    '(SABR': '(SBAR',
    '(SE-SPL': '(S-SPL',
    '(SBARR': '(SBAR',
    'PPADV': 'PP-ADV',
    '(PR (': '(PP (',
    '(PPP': '(PP',
    'VP0ADV': 'VP-ADV',
    '(S1': '(S',
    '(S2': '(S',
    '(S3': '(S',
    'BP-SUB': 'NP-SUB',
    'APPPD': 'AP-PPD',
    'APPRD': 'AP-PPD',
    'Np--H': 'Np-H',
    '(WPNP': '(WHNP',
    '(WHRPP': '(WHRP',
    # the one mistagged PV is on a prepositional phrase
    # (the subtree there maybe needs an SBAR as well, but who's counting)
    '(PV': '(PP',
    '(Mpd': '(MDP',
    # this only occurs on "bao giờ", "when"
    # that seems to be WHNP when under an SBAR, but WHRP otherwise
    '(Whadv ': '(WHRP ',
    # Whpr Occurs in two places: on "sao" in a context which is always WHRP,
    # and on "nào", which Vy says is more like a preposition
    '(Whpr (Pro-h nào))': '(WHPP (Pro-h nào))',
    '(Whpr (Pro-h Sao))': '(WHRP (Pro-h Sao))',
    # This is very clearly an NP: (Tp-tmp (N-h hiện nay))
    # which is only ever in NP-TMP contexts
    '(Tp-tmp': '(NP-TMP',
    # This occurs once, in the context of (Yp (SYM @))
    # The other times (SYM @) shows up, it's always NP
    '(Yp': '(NP',
}

def unify_label(tree):
    """
    Apply each REMAPPING substitution, in dict order, to the tree string.

    This is a plain text substitution over the bracketed tree, used to
    normalize known mistagged labels in the source treebank.
    """
    result = tree
    for pattern, replacement in REMAPPING.items():
        result = result.replace(pattern, replacement)
    return result
def is_valid_line(line):
    """
    Check whether a line looks like part of a bracketed constituent.

    Some "trees" in the raw data are just a long list of words with no
    bracket structure at all; those lines neither open nor close a
    parenthesis and should be discarded.

    :param line: constituent being read
    :return: True if the line starts with '(' or ends with ')'
    """
    return line.startswith('(') or line.endswith(')')
def convert_file(orig_file, new_file, fix_errors=True, convert_brackets=False, updated_tagset=False, write_ids=False):
    """
    Convert one raw VTB tree file into a cleaned .mrg tree file.

    :param orig_file: path of the original tree file to read
    :param new_file: path of the converted tree file to write
    :param fix_errors: apply the REMAPPING label substitutions before writing
    :param convert_brackets: rewrite RBKT/LBKT tokens as PTB -RRB-/-LRB-
    :param updated_tagset: use the 2023 VLSP list of weird labels instead of the older one
    :param write_ids: also write each tree's id marker around the tree
    :return: dict mapping error category -> list of error description strings;
      rejected trees are recorded here instead of being written
    """
    if updated_tagset:
        weird_labels = WEIRD_LABELS_2023
    else:
        weird_labels = WEIRD_LABELS
    errors = defaultdict(list)
    with open(orig_file, 'r', encoding='utf-8') as reader, open(new_file, 'w', encoding='utf-8') as writer:
        content = reader.readlines()
        # Tree string will only be written if the currently read
        # tree is a valid tree.  It will not be written if it
        # does not have a '(' that signifies the presence of constituents
        tree = ""
        tree_id = None
        reading_tree = False
        for line_idx, line in enumerate(content):
            # collapse all runs of whitespace to single spaces
            line = ' '.join(line.split())
            if line == '':
                continue
            # NOTE(review): the string literals in the next several branches look
            # garbled — they compare against '' and startswith("") and this elif
            # has no colon, and tree_id is used before being assigned.  The input
            # format appears to use XML-ish markers (e.g. <s id=N> ... </s>) whose
            # angle-bracket contents were stripped from this copy of the file, and
            # the branch that sets reading_tree = True seems to be missing too.
            # Recover the original tag literals before trusting this code.
            elif line == '' or line.startswith("")
                tree_id = int(tree_id[:-1])
            elif line == '' and reading_tree:
                # one tree in 25432.prd is not valid because
                # it is just a bunch of blank lines
                if tree.strip() == '(ROOT':
                    errors["empty"].append("Empty tree in {} line {}".format(orig_file, line_idx))
                    continue
                # close the ROOT bracket, then check the parens balance
                tree += ')\n'
                parity = count_paren_parity(tree)
                if parity > 0:
                    errors["unclosed"].append("Unclosed tree from {} line {}: |{}|".format(orig_file, line_idx, tree))
                    continue
                if parity < 0:
                    errors["extra_parens"].append("Extra parens at end of tree from {} line {} for having extra parens: {}".format(orig_file, line_idx, tree))
                    continue
                if convert_brackets:
                    tree = tree.replace("RBKT", "-RRB-").replace("LBKT", "-LRB-")
                try:
                    # test that the tree can be read in properly
                    processed_trees = read_trees(tree)
                    if len(processed_trees) > 1:
                        errors["multiple"].append("Multiple trees in one xml annotation from {} line {}".format(orig_file, line_idx))
                        continue
                    if len(processed_trees) == 0:
                        errors["empty"].append("Empty tree in {} line {}".format(orig_file, line_idx))
                        continue
                    if not processed_trees[0].all_leaves_are_preterminals():
                        errors["untagged_leaf"].append("Tree with non-preterminal leaves in {} line {}: {}".format(orig_file, line_idx, tree))
                        continue
                    # Unify the labels
                    if fix_errors:
                        tree = unify_label(tree)

                    # TODO: this block eliminates 3 trees from VLSP-22
                    # maybe those trees can be salvaged?
                    bad_label = False
                    for weird_label in weird_labels:
                        if tree.find(weird_label) >= 0:
                            bad_label = True
                            errors[weird_label].append("Weird label {} from {} line {}: {}".format(weird_label, orig_file, line_idx, tree))
                            break
                    if bad_label:
                        continue

                    if write_ids:
                        if tree_id is None:
                            errors["missing_id"].append("Missing ID from {} at line {}".format(orig_file, line_idx))
                            # NOTE(review): these write() format strings also look
                            # stripped — "\n" % tree_id would raise TypeError at
                            # runtime; presumably they originally wrote id markers
                            # such as "<s id=%d>\n".  Confirm against the VCS history.
                            writer.write("")
                        else:
                            writer.write("\n" % tree_id)
                    writer.write(tree)
                    if write_ids:
                        writer.write("\n")
                    # tree successfully written; reset state for the next tree
                    reading_tree = False
                    tree = ""
                    tree_id = None
                except MixedTreeError:
                    errors["mixed"].append("Mixed leaves and constituents from {} line {}: {}".format(orig_file, line_idx, tree))
                except UnlabeledTreeError:
                    errors["unlabeled"].append("Unlabeled nodes in tree from {} line {}: {}".format(orig_file, line_idx, tree))
            else: # content line
                if is_valid_line(line) and reading_tree:
                    tree += line
                elif reading_tree:
                    # a non-constituent line invalidates the whole current tree
                    errors["invalid"].append("Invalid tree error in {} line {}: |{}|, rejected because of line |{}|".format(orig_file, line_idx, tree, line))
                    reading_tree = False

    return errors

def convert_files(file_list, new_dir, verbose=False, fix_errors=True, convert_brackets=False, updated_tagset=False, write_ids=False):
    """
    Convert each file in file_list to a .mrg file in new_dir, then report errors.

    :param file_list: paths of the original tree files
    :param new_dir: directory for the converted .mrg files (one per input file)
    :param verbose: also print every individual error message, not just counts
    (remaining parameters are passed straight through to convert_file)
    """
    errors = defaultdict(list)
    for filename in file_list:
        # output file keeps the input's base name with a .mrg extension
        base_name, _ = os.path.splitext(os.path.split(filename)[-1])
        new_path = os.path.join(new_dir, base_name)
        new_file_path = f'{new_path}.mrg'
        # Convert the tree and write to new_file_path
        new_errors = convert_file(filename, new_file_path, fix_errors, convert_brackets, updated_tagset, write_ids)
        for e in new_errors:
            errors[e].extend(new_errors[e])

    if len(errors.keys()) == 0:
        print("All errors were fixed!")
    else:
        print("Found the following errors:")
        keys = sorted(errors.keys())
        if verbose:
            for e in keys:
                print("--------- %10s -------------" % e)
                print("\n\n".join(errors[e]))
                print()
                print()
        # always print the per-category counts
        for e in keys:
            print("%s: %d" % (e, len(errors[e])))
def convert_dir(orig_dir, new_dir):
    """
    Convert every tree file in orig_dir, writing one .mrg file per input to new_dir.

    Skips the .raw files shipped with VLSP 2009; everything else (the .prd
    files) is handed to convert_files.
    """
    file_list = os.listdir(orig_dir)
    # Only convert .prd files, skip the .raw files from VLSP 2009
    file_list = [os.path.join(orig_dir, f) for f in file_list if os.path.splitext(f)[1] != '.raw']
    convert_files(file_list, new_dir)

def main():
    """
    Converts files from the 2009 version of VLSP to .mrg files

    Process args, loop through each file in the directory and convert
    to the desired tree format
    """
    parser = argparse.ArgumentParser(
        description="Script that converts a VTB Tree into the desired format",
    )
    parser.add_argument(
        'orig_dir',
        help='The location of the original directory storing original trees '
    )
    parser.add_argument(
        'new_dir',
        help='The location of new directory storing the new formatted trees'
    )
    args = parser.parse_args()

    # bug fix: the positional argument is declared as 'orig_dir', so the parsed
    # namespace exposes args.orig_dir; the previous args.org_dir raised
    # AttributeError at runtime
    convert_dir(args.orig_dir, args.new_dir)


if __name__ == '__main__':
    main()
def create_shuffle_list(org_dir):
    """
    Return the directory's file names in a random order.

    The sorted() call makes the shuffle reproducible for a fixed random seed.

    :param org_dir: original directory storing the files that store the trees
    :return: list of file names randomly shuffled
    """
    names = sorted(os.listdir(org_dir))
    random.shuffle(names)
    return names


def create_paths(split_dir, short_name):
    """
    Build the train/dev/test output paths under split_dir.

    An empty (or None) short_name yields bare train.mrg / dev.mrg / test.mrg;
    otherwise the short name, with a trailing underscore, prefixes each name.

    :param split_dir: directory that stores the splits
    :return: train path, dev path, test path
    """
    prefix = short_name if short_name else ""
    if prefix and not prefix.endswith("_"):
        prefix = prefix + "_"

    return tuple(os.path.join(split_dir, '%s%s.mrg' % (prefix, shard))
                 for shard in ('train', 'dev', 'test'))
def split_files(org_dir, split_dir, short_name=None, train_size=0.7, dev_size=0.15, rotation=None):
    """
    Shuffle the .mrg files in org_dir and write train/dev/test splits to split_dir.

    :param org_dir: directory of .mrg files (one tree per line) to split
    :param split_dir: directory for the output split files (created if absent)
    :param short_name: optional prefix for the output file names
    :param train_size: fraction of lines for the train split
    :param dev_size: fraction of lines for the dev split; the remainder is test.
      If train_size + dev_size >= 1.0, no test split is written; if their sum
      is 0, everything goes to test.
    :param rotation: optional (i, n) pair which rotates the train+dev lines by
      i/n of the data, leaving the test section unchanged — used to build
      cross-validation style alternate splits
    :raises RuntimeError: if fewer trees are read than the counts computed up front
    """
    os.makedirs(split_dir, exist_ok=True)

    if train_size + dev_size >= 1.0:
        print("Not making a test slice with the given ratios: train {} dev {}".format(train_size, dev_size))

    # Create a random shuffle list of the file names in the original directory
    file_names = create_shuffle_list(org_dir)

    # Create train_path, dev_path, test_path
    train_path, dev_path, test_path = create_paths(split_dir, short_name)

    # Set up the number of samples for each train/dev/test set
    # TODO: if we ever wanted to split files with  in them,
    # this particular code would need some updating to pay attention to the ids
    # NOTE(review): the comment above appears to have lost an id-marker literal
    # (likely "<s id=...>") to tag stripping — confirm against VCS history
    num_samples = get_num_samples(org_dir, file_names)
    print("Found {} total lines in {}".format(num_samples, org_dir))

    # compute cumulative line counts at which each output file stops
    stop_train = int(num_samples * train_size)
    if train_size + dev_size >= 1.0:
        stop_dev = num_samples
        output_limits = (stop_train, stop_dev)
        output_names = (train_path, dev_path)
        print("Splitting {} train, {} dev".format(stop_train, stop_dev - stop_train))
    elif train_size + dev_size > 0.0:
        stop_dev = int(num_samples * (train_size + dev_size))
        output_limits = (stop_train, stop_dev, num_samples)
        output_names = (train_path, dev_path, test_path)
        print("Splitting {} train, {} dev, {} test".format(stop_train, stop_dev - stop_train, num_samples - stop_dev))
    else:
        stop_dev = 0
        output_limits = (num_samples,)
        output_names = (test_path,)
        print("Copying all {} lines to test".format(num_samples))

    # Count how much stuff we've written.
    # We will switch to the next output file when we're written enough
    count = 0

    # read every non-blank line from every .mrg file, in shuffled file order
    trees = []
    for filename in file_names:
        if not filename.endswith('.mrg'):
            continue
        with open(os.path.join(org_dir, filename), encoding='utf-8') as reader:
            new_trees = reader.readlines()
            new_trees = [x.strip() for x in new_trees]
            new_trees = [x for x in new_trees if x]
            trees.extend(new_trees)
    # rotate the train & dev sections, leave the test section the same
    if rotation is not None and rotation[0] > 0:
        rotation_start = len(trees) * rotation[0] // rotation[1]
        rotation_end = stop_dev
        # if there are no test trees, rotation_end: will be empty anyway
        trees = trees[rotation_start:rotation_end] + trees[:rotation_start] + trees[rotation_end:]
    tree_iter = iter(trees)
    # count is cumulative across output files, matching the cumulative limits
    for write_path, count_limit in zip(output_names, output_limits):
        with open(write_path, 'w', encoding='utf-8') as writer:
            # Loop through the files, which then loop through the trees and write to write_path
            while count < count_limit:
                next_tree = next(tree_iter, None)
                if next_tree is None:
                    raise RuntimeError("Ran out of trees before reading all of the expected trees")
                # Write to write_dir
                writer.write(next_tree)
                writer.write("\n")
                count += 1

def main():
    """
    Main function for the script

    Process args, loop through each tree in each file in the directory
    and write the trees to the train/dev/test split with a split of
    70/15/15
    """
    parser = argparse.ArgumentParser(
        description="Script that splits a list of files of vtb trees into train/dev/test sets",
    )
    parser.add_argument(
        'org_dir',
        help='The location of the original directory storing correctly formatted vtb trees '
    )
    parser.add_argument(
        'split_dir',
        help='The location of new directory storing the train/dev/test set'
    )

    args = parser.parse_args()

    org_dir = args.org_dir
    split_dir = args.split_dir

    # fixed seed so repeated runs produce the same split
    random.seed(1234)

    split_files(org_dir, split_dir)
def main(args=None):
    """
    Combine several NER datasets into a single output dataset.

    By default both input and output live in stanza's standard NER data
    directory; --input_dir / --output_dir override that.
    """
    default_dir = get_default_paths()['NER_DATA_DIR']

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dataset', type=str, help='What dataset to output')
    parser.add_argument('input_datasets', type=str, nargs='+', help='Which datasets to input')
    parser.add_argument('--input_dir', type=str, default=default_dir, help='Which directory to find the datasets')
    parser.add_argument('--output_dir', type=str, default=default_dir, help='Which directory to write the dataset')
    parsed = parser.parse_args(args)

    combine_dataset(parsed.input_dir, parsed.output_dir, parsed.input_datasets, parsed.output_dataset)

if __name__ == '__main__':
    main()
def read_json(input_filename):
    """
    Read the json file and extract the NER labels

    Will not return lines which are not labeled

    Return format is a list of lines
    where each line is a tuple: (text, labels)
    labels is a list of maps, {'label':..., 'startOffset':..., 'endOffset':...}
    """
    docs = []
    blank = 0
    unlabeled = 0
    broken = 0
    with open(input_filename, encoding="utf-8") as fin:
        for line_idx, line in enumerate(fin):
            record = json.loads(line)
            if sorted(record.keys()) == ['source']:
                # text with no annotation payload at all
                unlabeled += 1
                continue
            if 'source' not in record:
                blank += 1
                continue
            source = record['source']

            # exactly one non-metadata key may carry an entities annotation
            entities = None
            for key in record.keys():
                if key == 'source' or key.endswith('metadata'):
                    continue
                payload = record[key]
                if 'annotations' not in payload:
                    continue
                annotations = payload['annotations']
                if 'entities' not in annotations:
                    continue
                if entities is not None:
                    raise ValueError("Found a map with multiple annotations at line %d" % line_idx)
                entities = annotations['entities']
            # entities is now a map such as
            # [{'label': 'Location', 'startOffset': 0, 'endOffset': 6}, ...]
            if entities is None:
                unlabeled += 1
                continue

            # broken entities are counted (and the first one reported),
            # but the doc is still kept
            missing_field = any(field not in entity
                                for entity in entities
                                for field in ('label', 'startOffset', 'endOffset'))
            if missing_field:
                broken += 1
                if broken == 1:
                    print("Found an entity which was missing either label, startOffset, or endOffset")
                    print(entities)
            docs.append((source, entities))

    print("Found %d labeled lines. %d lines were blank, %d lines were broken, and %d lines were unlabeled" % (len(docs), blank, broken, unlabeled))
    return docs
def remove_ignored_labels(docs, ignored):
    """
    Drop every annotation whose label is in the comma-separated `ignored` list.

    The docs themselves are kept (possibly with empty label lists); a falsy
    `ignored` returns the input unchanged.
    """
    if not ignored:
        return docs

    dropped = set(ignored.split(","))
    # drop all labels which match something in ignored
    # otherwise leave everything the same
    return [(text, [entity for entity in labels if entity['label'] not in dropped])
            for text, labels in docs]

def remap_labels(docs, remap):
    """
    Rename entity labels per a comma-separated list of OLD=NEW pairs.

    Labels not mentioned in the mapping are left alone.  The input docs are
    not mutated; entity dicts are deep-copied before relabeling.
    """
    if not remap:
        return docs

    mapping = {}
    for pair in remap.split(","):
        pieces = pair.split("=")
        mapping[pieces[0]] = pieces[1]

    print(mapping)

    remapped = []
    for text, labels in docs:
        entities = copy.deepcopy(labels)
        for entity in entities:
            entity['label'] = mapping.get(entity['label'], entity['label'])
        remapped.append((text, entities))
    return remapped
def process_doc(source, labels, pipe):
    """
    Given a source text and a list of labels, tokenize the text, then assign labels based on the spans defined

    :param source: raw text of one document
    :param labels: list of {'label', 'startOffset', 'endOffset'} dicts with
      character offsets into source
    :param pipe: a stanza tokenize Pipeline
    :return: list of sentences, each a list of (token text, BIO ner tag) pairs
    :raises AssertionError: if a label span cannot be matched to tokens
    NOTE: mutates the .ner attribute of the tokens in the Document pipe produces
    """
    doc = pipe(source)
    sentences = doc.sentences
    # start with everything outside any entity
    for sentence in sentences:
        for token in sentence.tokens:
            token.ner = "O"

    for label in labels:
        ner = label['label']
        start_offset = label['startOffset']
        end_offset = label['endOffset']
        # find the single sentence whose character span covers the label;
        # labels crossing a sentence boundary are silently skipped
        for sentence in sentences:
            if (sentence.tokens[0].start_char <= start_offset and
                sentence.tokens[-1].end_char >= end_offset):
                # found the sentence!
                break
        else: # for-else again! deal with it
            continue

        start_token = None
        end_token = None
        for token_idx, token in enumerate(sentence.tokens):
            if token.start_char <= start_offset and token.end_char > start_offset:
                # ideally we'd have start_char == start_offset, but maybe our
                # tokenization doesn't match the tokenization of the annotators
                start_token = token
                start_token.ner = "B-" + ner
            elif start_token is not None:
                if token.start_char >= end_offset and token_idx > 0:
                    # this token starts past the span: previous token ends the entity
                    end_token = sentence.tokens[token_idx-1]
                    break
                # NOTE(review): this branch apparently excludes a trailing ',' or '.'
                # whose end_char happens to match the span end from the entity —
                # presumably to counter annotators including punctuation; confirm
                if token.end_char == end_offset and token_idx > 0 and token.text in (',', '.'):
                    end_token = sentence.tokens[token_idx-1]
                    break
                token.ner = "I-" + ner
                if token.end_char >= end_offset and end_token is None:
                    end_token = token
                    break
        if start_token is None or end_token is None:
            raise AssertionError("This should not happen")

    return [[(token.text, token.ner) for token in sentence.tokens] for sentence in sentences]
def main(args):
    """
    Read in a .json file of labeled data from AMT, write out a converted .bio file

    Enforces that there is only one set of labels on a sentence
    (TODO: add an option to skip certain sets of labels)

    After writing the .bio file, also writes the dataset as .json, honoring
    --json_output_path when given and otherwise deriving a name from the
    .bio output path.
    """
    docs = read_json(args.input_path)

    if len(docs) == 0:
        print("Error: no documents found in the input file!")
        return

    # apply the requested label filtering / renaming, then drop nested entities
    docs = remove_ignored_labels(docs, args.ignore)
    docs = remap_labels(docs, args.remap)
    docs = remove_nesting(docs)

    pipe = stanza.Pipeline(args.language, processors="tokenize")
    sentences = []
    for doc in tqdm(docs):
        sentences.extend(process_doc(*doc, pipe))
    print("Found %d total sentences (may be more than #docs if a doc has more than one sentence)" % len(sentences))
    bio_filename = args.output_path
    write_sentences(args.output_path, sentences)
    print("Sentences written to %s" % args.output_path)
    # bug fix: --json_output_path was parsed but never consulted; use it when
    # given, and only fall back to guessing a name from the .bio filename
    # when it is absent (the previous default behavior)
    if getattr(args, 'json_output_path', None):
        json_filename = args.json_output_path
    elif bio_filename.endswith(".bio"):
        json_filename = bio_filename[:-4] + ".json"
    else:
        json_filename = bio_filename + ".json"
    prepare_ner_file.process_dataset(bio_filename, json_filename)
def main():
    """
    Query a pretrain word vector file for the presence of specific words.

    Either reads an already-converted .pt file (--pretrain) or downloads a
    named pretrain package (--package), then prints "word: True/False" for
    each requested word depending on vocabulary membership.
    """
    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--pretrain", default=None, type=str, help="Where to read the converted PT file")
    group.add_argument("--package", default=None, type=str, help="Use a pretrain package instead")
    parser.add_argument("--download_json", default=False, action='store_true', help="Download the json even if it already exists")
    parser.add_argument("words", type=str, nargs="+", help="Which words to search for")
    args = parser.parse_args()

    if args.pretrain:
        pretrain = Pretrain(args.pretrain)
    else:
        # package names look like "<lang>_<package>"; fetch just the pretrain piece
        lang, package = args.package.split("_", 1)
        download(lang=lang, package=None, processors={"pretrain": package}, download_json=args.download_json)
        pretrain = Pretrain(os.path.join(DEFAULT_MODEL_DIR, lang, "pretrain", "%s.pt" % package))

    for word in args.words:
        print("{}: {}".format(word, word in pretrain.vocab))

if __name__ == "__main__":
    main()
def main():
    """
    Add constituency parses to a sentiment dataset produced by prepare_sentiment_dataset.

    Reads either a single dataset file or the train/dev/test shards of a named
    dataset (rebuilding them if missing), parses every sentence with a
    tokenize+pos+constituency pipeline (optionally resplitting MWTs first),
    attaches each parse to its SentimentDatum, and writes the result back
    (or to --output).
    """
    parser = argparse.ArgumentParser()
    # TODO: allow multiple files?
    parser.add_argument('dataset', type=str, help="Dataset (or a single file) to process")
    parser.add_argument('--output', type=str, help="Write the processed data here instead of clobbering")
    parser.add_argument('--constituency_package', type=str, default=None, help="Constituency model to use for parsing")
    parser.add_argument('--constituency_model', type=str, default=None, help="Specific model file to use for parsing")
    parser.add_argument('--retag_package', type=str, default=None, help="Which tagger to use for retagging")
    parser.add_argument('--split_mwt', action='store_true', help="Split MWT from the original sentences if the language has MWT")
    parser.add_argument('--lang', type=str, default=None, help="Which language the dataset/file is in. If not specified, will try to use the dataset name")
    args = parser.parse_args()

    if os.path.exists(args.dataset):
        # single-file mode: the positional arg is a path, not a dataset name
        expected_files = [args.dataset]
        if args.output:
            output_files = [args.output]
        else:
            output_files = expected_files
        if not args.lang:
            # dataset files are conventionally named <lang>_<rest>
            _, filename = os.path.split(args.dataset)
            args.lang = filename.split("_")[0]
            print("Guessing lang=%s based on the filename %s" % (args.lang, filename))
    else:
        # dataset mode: look for <name>.<shard>.json in the sentiment data dir
        paths = default_paths.get_default_paths()
        # TODO: one of the side effects of the tass2020 dataset is to make a bunch of extra files
        # Perhaps we could have the prepare_sentiment_dataset script return a list of those files
        expected_files = [os.path.join(paths['SENTIMENT_DATA_DIR'], '%s.%s.json' % (args.dataset, shard)) for shard in SHARDS]
        if args.output:
            output_files = [os.path.join(paths['SENTIMENT_DATA_DIR'], '%s.%s.json' % (args.output, shard)) for shard in SHARDS]
        else:
            output_files = expected_files
        # any missing shard triggers a full rebuild of the dataset
        for filename in expected_files:
            if not os.path.exists(filename):
                print("Cannot find expected dataset file %s - rebuilding dataset" % filename)
                prepare_sentiment_dataset.main(args.dataset)
                break
        if not args.lang:
            args.lang, _ = args.dataset.split("_", 1)
            print("Guessing lang=%s based on the dataset name" % args.lang)


    # tokenize_pretokenized: the dataset is already split into words
    pipeline_args = {"lang": args.lang,
                     "processors": "tokenize,pos,constituency",
                     "tokenize_pretokenized": True,
                     "pos_batch_size": 50,
                     "pos_tqdm": True,
                     "constituency_tqdm": True}
    package = {}
    if args.constituency_package is not None:
        package["constituency"] = args.constituency_package
    if args.retag_package is not None:
        package["pos"] = args.retag_package
    if package:
        pipeline_args["package"] = package
    if args.constituency_model is not None:
        pipeline_args["constituency_model_path"] = args.constituency_model
    pipe = stanza.Pipeline(**pipeline_args)

    if args.split_mwt:
        # TODO: allow for different tokenize packages
        # a separate non-pretokenized pipeline is needed to detect MWTs
        mwt_pipe = stanza.Pipeline(lang=args.lang, processors="tokenize")
        if "mwt" in mwt_pipe.processors:
            print("This language has MWT. Will resplit any MWTs found in the dataset")
        else:
            print("--split_mwt was requested, but %s does not support MWT!" % args.lang)
            args.split_mwt = False

    for filename, output_filename in zip(expected_files, output_files):
        dataset = read_dataset(filename, WVType.OTHER, 1)
        text = [x.text for x in dataset]
        if args.split_mwt:
            print("Resplitting MWT in %d sentences from %s" % (len(dataset), filename))
            doc = resplit_mwt(text, mwt_pipe)
            print("Parsing %d sentences from %s" % (len(dataset), filename))
            doc = pipe(doc)
        else:
            print("Parsing %d sentences from %s" % (len(dataset), filename))
            doc = pipe(text)

        # one parsed sentence per datum, in order
        assert len(dataset) == len(doc.sentences)
        for datum, sentence in zip(dataset, doc.sentences):
            datum.constituency = sentence.constituency

        process_utils.write_list(output_filename, dataset)
+""" + +import os + +import stanza +from stanza.models.classifiers.data import SentimentDatum +from stanza.utils.datasets.sentiment import process_utils +import stanza.utils.default_paths as default_paths + +def main(): + paths = default_paths.get_default_paths() + + dataset_name = "it_vit_sentences_poetry" + + poetry_filename = os.path.join(paths["SENTIMENT_BASE"], "italian", "sentence_classification", "poetry", "testset_300_labeled.txt") + if not os.path.exists(poetry_filename): + raise FileNotFoundError("Expected to find the labeled file in %s" % poetry_filename) + print("Reading the labeled poetry from %s" % poetry_filename) + + tokenizer = stanza.Pipeline("it", processors="tokenize", tokenize_no_ssplit=True) + dataset = [] + with open(poetry_filename, encoding="utf-8") as fin: + for line_num, line in enumerate(fin): + line = line.strip() + if not line: + continue + + line = line.replace(u'\ufeff', '') + pieces = line.split(maxsplit=1) + # first column is the label + # remainder of the text is the raw text + label = pieces[0].strip() + if label not in ('0', '1'): + if label == "viene" and line_num == 257: + print("Skipping known missing label at line 257") + continue + assert isinstance(label, str) + ords = ",".join(str(ord(x)) for x in label) + raise ValueError("Unexpected label |%s| (%s) for line %d" % (label, ords, line_num)) + + # tokenize the text into words + # we could make this faster by stacking it, but the input file is quite short anyway + text = pieces[1] + doc = tokenizer(text) + words = [x.text for x in doc.sentences[0].words] + + dataset.append(SentimentDatum(label, words)) + + print("Read %d lines from %s" % (len(dataset), poetry_filename)) + output_filename = "%s.test.json" % dataset_name + output_path = os.path.join(paths["SENTIMENT_DATA_DIR"], output_filename) + print("Writing output to %s" % output_path) + process_utils.write_list(output_path, dataset) + + +if __name__ == '__main__': + main() diff --git 
a/stanza/stanza/utils/datasets/sentiment/convert_italian_sentence_classification.py b/stanza/stanza/utils/datasets/sentiment/convert_italian_sentence_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..68d4df414eb4e32581503d87b94f195cf55e9c55 --- /dev/null +++ b/stanza/stanza/utils/datasets/sentiment/convert_italian_sentence_classification.py @@ -0,0 +1,85 @@ +""" +Converts a file of labels on constituency trees for the it_vit dataset + +The labels are for whether or not a sentence is written in a standard +S-V-O order. The intent is to see how much a constituency parser +can improve over a regular transformer classifier. + +This file is provided by Prof. Delmonte as part of a classification +project. Contact John Bauer for more details. + +Technically this should be "classifier" instead of "sentiment" +""" + +import os + +from stanza.models.classifiers.data import SentimentDatum +from stanza.utils.datasets.sentiment import process_utils +from stanza.utils.datasets.constituency.convert_it_vit import read_updated_trees +import stanza.utils.default_paths as default_paths + +def label_trees(label_map, trees): + new_trees = [] + for tree in trees: + if tree.con_id not in label_map: + raise ValueError("%s not labeled" % tree.con_id) + label = label_map[tree.con_id] + new_trees.append(SentimentDatum(label, tree.tree.leaf_labels(), tree.tree)) + return new_trees + +def read_label_map(label_filename): + with open(label_filename, encoding="utf-8") as fin: + lines = fin.readlines() + lines = [x.strip() for x in lines] + lines = [x.split() for x in lines if x] + label_map = {} + for line_idx, line in enumerate(lines): + k = line[0].split("#")[1] + v = line[1] + + # compensate for an off-by-one error in the labels for ids 12 through 129 + # we went back and forth a few times but i couldn't explain the error, + # so whatever, just compensate for it on the conversion side + k_idx = int(k.split("_")[1]) + if k_idx != line_idx + 1: + if k_idx >= 
def get_phrases(in_directory):
    """
    Read Tweets.csv from in_directory and return a list of SentimentDatum.

    The text is left untokenized; labels are remapped to the
    0 (negative) / 1 (neutral) / 2 (positive) scheme.
    """
    label_map = {'negative': '0', 'neutral': '1', 'positive': '2'}

    csv_path = os.path.join(in_directory, "Tweets.csv")
    with open(csv_path, newline='') as fin:
        rows = list(csv.reader(fin, delimiter=',', quotechar='"'))

    phrases = []
    # rows[0] is the header row
    for row in rows[1:]:
        raw_label = row[1]
        if raw_label not in label_map:
            raise ValueError("Unknown sentiment: {}".format(raw_label))
        # column 10 holds the tweet text; some of the tweets have \n in them
        tweet_text = row[10].replace("\n", " ")
        phrases.append(SentimentDatum(label_map[raw_label], tweet_text))

    return phrases
def get_phrases(filename):
    """
    Extract labeled spans from one ArguAna XMI file.

    Fact annotations are treated as neutral ("1"); Opinion annotations
    map negative -> "0" and positive -> "2".  The span text is sliced
    out of the document body (the Sofa element's sofaString).
    """
    sofa_tag = '{http:///uima/cas.ecore}Sofa'
    fact_tag = '{http:///de/aitools/ie/uima/type/arguana.ecore}Fact'
    opinion_tag = '{http:///de/aitools/ie/uima/type/arguana.ecore}Opinion'
    polarity_map = {'negative': "0", 'positive': "2"}

    root = ET.parse(filename).getroot()
    body = None
    fragments = []
    for child in root:
        if child.tag == sofa_tag:
            # the full review text; all begin/end offsets index into it
            body = child.attrib['sofaString']
            continue
        if child.tag == fact_tag:
            rating = "1"
        elif child.tag == opinion_tag:
            polarity = child.attrib['polarity']
            if polarity not in polarity_map:
                raise ValueError("Unexpected polarity found in {}".format(filename))
            rating = polarity_map[polarity]
        else:
            continue
        fragments.append(ArguanaSentimentDatum(begin=int(child.attrib['begin']),
                                               end=int(child.attrib['end']),
                                               rating=rating))

    return [SentimentDatum(fragment.rating, body[fragment.begin:fragment.end])
            for fragment in fragments]
+ +There are two subtasks in TASS 2020. One is split among 5 Spanish +speaking countries, and the other is combined across all of the +countries. Here we combine all of the data into one output file. + +Also, each of the subparts are output into their own files, such as +p2.json, p1.mx.json, etc +""" + +import os +import zipfile + +import stanza + +from stanza.models.classifiers.data import SentimentDatum +import stanza.utils.default_paths as default_paths +from stanza.utils.datasets.sentiment.process_utils import write_dataset, write_list + +def convert_label(label): + """ + N/NEU/P or error + """ + if label == "N": + return 0 + if label == "NEU": + return 1 + if label == "P": + return 2 + raise ValueError("Unexpected label %s" % label) + +def read_test_labels(fin): + """ + Read a tab (or space) separated list of id/label pairs + """ + label_map = {} + for line_idx, line in enumerate(fin): + if isinstance(line, bytes): + line = line.decode("utf-8") + pieces = line.split() + if len(pieces) < 2: + continue + if len(pieces) > 2: + raise ValueError("Unexpected format at line %d: all label lines should be len==2\n%s" % (line_idx, line)) + + datum_id, label = pieces + try: + label = convert_label(label) + except ValueError: + raise ValueError("Unexpected test label %s at line %d\n%s" % (label, line_idx, line)) + + label_map[datum_id] = label + return label_map + +def open_read_test_labels(filename, zip_filename=None): + """ + Open either a text or zip file, then read the labels + """ + if zip_filename is None: + with open(filename, encoding="utf-8") as fin: + test_labels = read_test_labels(fin) + print("Read %d lines from %s" % (len(test_labels), filename)) + return test_labels + + with zipfile.ZipFile(zip_filename) as zin: + with zin.open(filename) as fin: + test_labels = read_test_labels(fin) + print("Read %d lines from %s - %s" % (len(test_labels), zip_filename, filename)) + return test_labels + + +def read_sentences(fin): + """ + Read ids and text from the given 
def read_test_set(label_zip_filename, label_filename, sentence_zip_filename, sentence_filename, pipe):
    """
    Read and tokenize an entire test set given the label and sentence filenames

    label_zip_filename may be None, in which case the labels are read
    from a plain file rather than from inside a zip.  The sentences
    always come from inside a zip.

    Returns a list of SentimentDatum with tokenized text.
    """
    test_labels = open_read_test_labels(label_filename, label_zip_filename)
    test_sentences = open_read_sentences(sentence_filename, sentence_zip_filename)
    # join the sentence text with its gold label by ID
    sentiment_data = combine_test_set(test_sentences, test_labels)
    # tokenize in one batch for efficiency
    # (the original had a second, unreachable "return sentiment_data"
    # after this return; it has been removed)
    return tokenize(sentiment_data, pipe)
+ """ + in_directory = os.path.join(in_directory, "spanish", "tass2020") + + pipe = stanza.Pipeline(lang="es", processors="tokenize", tokenize_no_ssplit=True) + + test_11 = {} + test_11_labels_zip = os.path.join(in_directory, "tass2020-test-gold.zip") + test_11_sentences_zip = os.path.join(in_directory, "Test1.1.zip") + for piece in DATASET_PIECES: + inner_label_filename = piece + ".tsv" + inner_sentence_filename = os.path.join("Test1.1", piece.upper() + ".tsv") + test_11[piece] = read_test_set(test_11_labels_zip, inner_label_filename, + test_11_sentences_zip, inner_sentence_filename, pipe) + + test_12_label_filename = os.path.join(in_directory, "task1.2-test-gold.tsv") + test_12_sentences_zip = os.path.join(in_directory, "test1.2.zip") + test_12_sentences_filename = "test1.2/task1.2.tsv" + test_12 = read_test_set(None, test_12_label_filename, + test_12_sentences_zip, test_12_sentences_filename, pipe) + + train_dev_zip = os.path.join(in_directory, "Task1-train-dev.zip") + dev = {} + train = {} + for piece in DATASET_PIECES: + dev_filename = os.path.join("dev", piece + ".tsv") + dev[piece] = read_train_file(train_dev_zip, dev_filename, pipe) + + for piece in DATASET_PIECES: + train_filename = os.path.join("train", piece + ".tsv") + train[piece] = read_train_file(train_dev_zip, train_filename, pipe) + + all_test = test_12 + [item for piece in test_11.values() for item in piece] + all_dev = [item for piece in dev.values() for item in piece] + all_train = [item for piece in train.values() for item in piece] + + print("Total train items: %8d" % len(all_train)) + print("Total dev items: %8d" % len(all_dev)) + print("Total test items: %8d" % len(all_test)) + + write_dataset((all_train, all_dev, all_test), out_directory, dataset_name) + + output_file = os.path.join(out_directory, "%s.test.p2.json" % dataset_name) + write_list(output_file, test_12) + + for piece in DATASET_PIECES: + output_file = os.path.join(out_directory, "%s.test.p1.%s.json" % (dataset_name, piece)) + 
write_list(output_file, test_11[piece]) + +def main(paths): + in_directory = paths['SENTIMENT_BASE'] + out_directory = paths['SENTIMENT_DATA_DIR'] + + convert_tass2020(in_directory, out_directory, "es_tass2020") + + +if __name__ == '__main__': + paths = default_paths.get_default_paths() + main(paths) diff --git a/stanza/stanza/utils/datasets/sentiment/process_ren_chinese.py b/stanza/stanza/utils/datasets/sentiment/process_ren_chinese.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb82f27f4e9fb666b3c0edc64bd82cf2f1a1bb5 --- /dev/null +++ b/stanza/stanza/utils/datasets/sentiment/process_ren_chinese.py @@ -0,0 +1,93 @@ +import glob +import os +import random +import sys + +import xml.etree.ElementTree as ET + +from collections import namedtuple + +import stanza + +from stanza.models.classifiers.data import SentimentDatum +import stanza.utils.datasets.sentiment.process_utils as process_utils + +""" +This processes a Chinese corpus, hosted here: + +http://a1-www.is.tokushima-u.ac.jp/member/ren/Ren-CECps1.0/Ren-CECps1.0.html + +The authors want a signed document saying you won't redistribute the corpus. + +The corpus format is a bunch of .xml files, with sentences labeled with various emotions and an overall polarity. 
Polarity is labeled as follows: + +消极: negative +中性: neutral +积极: positive +""" + +def get_phrases(filename): + tree = ET.parse(filename) + fragments = [] + + root = tree.getroot() + for child in root: + if child.tag == 'paragraph': + for subchild in child: + if subchild.tag == 'sentence': + text = subchild.attrib['S'].strip() + if len(text) <= 2: + continue + polarity = None + for inner in subchild: + if inner.tag == 'Polarity': + polarity = inner + break + if polarity is None: + print("Found sentence with no polarity in {}: {}".format(filename, text)) + continue + if polarity.text == '消极': + sentiment = "0" + elif polarity.text == '中性': + sentiment = "1" + elif polarity.text == '积极': + sentiment = "2" + else: + raise ValueError("Unknown polarity {} in {}".format(polarity.text, filename)) + fragments.append(SentimentDatum(sentiment, text)) + + return fragments + +def read_snippets(xml_directory): + sentences = [] + for filename in glob.glob(xml_directory + '/xml/cet_*xml'): + sentences.extend(get_phrases(filename)) + + nlp = stanza.Pipeline('zh', processors='tokenize') + snippets = [] + for sentence in sentences: + doc = nlp(sentence.text) + text = [token.text for sentence in doc.sentences for token in sentence.tokens] + snippets.append(SentimentDatum(sentence.sentiment, text)) + random.shuffle(snippets) + return snippets + +def main(xml_directory, out_directory, short_name): + snippets = read_snippets(xml_directory) + + print("Found {} phrases".format(len(snippets))) + os.makedirs(out_directory, exist_ok=True) + process_utils.write_splits(out_directory, + snippets, + (process_utils.Split("%s.train.json" % short_name, 0.8), + process_utils.Split("%s.dev.json" % short_name, 0.1), + process_utils.Split("%s.test.json" % short_name, 0.1))) + + +if __name__ == "__main__": + random.seed(1234) + xml_directory = sys.argv[1] + out_directory = sys.argv[2] + short_name = sys.argv[3] + main(xml_directory, out_directory, short_name) + diff --git 
a/stanza/stanza/utils/datasets/sentiment/process_slsd.py b/stanza/stanza/utils/datasets/sentiment/process_slsd.py new file mode 100644 index 0000000000000000000000000000000000000000..b370f4905b82a9e407d69b6e55bb38d61a9bd78f --- /dev/null +++ b/stanza/stanza/utils/datasets/sentiment/process_slsd.py @@ -0,0 +1,71 @@ +""" +A small dataset of 1500 positive and 1500 negative sentences. +Supposedly has no neutral sentences by design + +https://archive.ics.uci.edu/ml/datasets/Sentiment+Labelled+Sentences + +https://archive.ics.uci.edu/ml/machine-learning-databases/00331/ + +See the existing readme for citation requirements etc + +Files in the slsd repo were one line per annotation, with labels 0 +for negative and 1 for positive. No neutral labels existed. + +Accordingly, we rearrange the text and adjust the label to fit the +0/1/2 paradigm. Text is retokenized using PTBTokenizer. + + + +process_slsd.py +""" + +import os +import sys + +from stanza.models.classifiers.data import SentimentDatum +import stanza.utils.datasets.sentiment.process_utils as process_utils + +def get_phrases(in_directory): + in_filenames = [os.path.join(in_directory, 'amazon_cells_labelled.txt'), + os.path.join(in_directory, 'imdb_labelled.txt'), + os.path.join(in_directory, 'yelp_labelled.txt')] + + lines = [] + for filename in in_filenames: + lines.extend(open(filename, newline='')) + + phrases = [] + for line in lines: + line = line.strip() + sentiment = line[-1] + utterance = line[:-1] + utterance = utterance.replace("!.", "!") + utterance = utterance.replace("?.", "?") + if sentiment == '0': + sentiment = '0' + elif sentiment == '1': + sentiment = '2' + else: + raise ValueError("Unknown sentiment: {}".format(sentiment)) + phrases.append(SentimentDatum(sentiment, utterance)) + + return phrases + +def get_tokenized_phrases(in_directory): + phrases = get_phrases(in_directory) + phrases = process_utils.get_ptb_tokenized_phrases(phrases) + print("Found %d phrases in slsd" % len(phrases)) + return 
phrases + +def main(in_directory, out_directory, short_name): + phrases = get_tokenized_phrases(in_directory) + out_filename = os.path.join(out_directory, "%s.train.json" % short_name) + os.makedirs(out_directory, exist_ok=True) + process_utils.write_list(out_filename, phrases) + + +if __name__ == '__main__': + in_directory = sys.argv[1] + out_directory = sys.argv[2] + short_name = sys.argv[3] + main(in_directory, out_directory, short_name) diff --git a/stanza/stanza/utils/datasets/sentiment/process_sst.py b/stanza/stanza/utils/datasets/sentiment/process_sst.py new file mode 100644 index 0000000000000000000000000000000000000000..a672fc3177025d60553ccf517290562963998fae --- /dev/null +++ b/stanza/stanza/utils/datasets/sentiment/process_sst.py @@ -0,0 +1,84 @@ +import argparse +import os +import subprocess + +from stanza.models.classifiers.data import SentimentDatum +import stanza.utils.datasets.sentiment.process_utils as process_utils + +import stanza.utils.default_paths as default_paths + +TREEBANK_FILES = ["train.txt", "dev.txt", "test.txt", "extra-train.txt", "checked-extra-train.txt"] + +ARGUMENTS = { + "fiveclass": [], + "root": ["-root_only"], + "binary": ["-ignore_labels", "2", "-remap_labels", "1=0,2=-1,3=1,4=1"], + "binaryroot": ["-root_only", "-ignore_labels", "2", "-remap_labels", "1=0,2=-1,3=1,4=1"], + "threeclass": ["-remap_labels", "0=0,1=0,2=1,3=2,4=2"], + "threeclassroot": ["-root_only", "-remap_labels", "0=0,1=0,2=1,3=2,4=2"], +} + + +def get_subtrees(input_file, *args): + """ + Use the CoreNLP OutputSubtrees tool to convert the input file to a bunch of phrases + + Returns a list of the SentimentDatum namedtuple + """ + # TODO: maybe can convert this to use the python tree? 
+ cmd = ["java", "edu.stanford.nlp.trees.OutputSubtrees", "-input", input_file] + if len(args) > 0: + cmd = cmd + list(args) + print (" ".join(cmd)) + results = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8") + lines = results.stdout.split("\n") + lines = [x.strip() for x in lines] + lines = [x for x in lines if x] + lines = [x.split(maxsplit=1) for x in lines] + phrases = [SentimentDatum(x[0], x[1].split()) for x in lines] + return phrases + +def get_phrases(dataset, treebank_file, input_dir): + extra_args = ARGUMENTS[dataset] + + input_file = os.path.join(input_dir, "fiveclass", treebank_file) + if not os.path.exists(input_file): + raise FileNotFoundError(input_file) + phrases = get_subtrees(input_file, *extra_args) + print("Found {} phrases in SST {} {}".format(len(phrases), treebank_file, dataset)) + return phrases + +def convert_version(dataset, treebank_file, input_dir, output_dir): + """ + Convert the fiveclass files to a specific format + + Uses the ARGUMENTS specific for the format wanted + """ + phrases = get_phrases(dataset, treebank_file, input_dir) + output_file = os.path.join(output_dir, "en_sst.%s.%s.json" % (dataset, treebank_file.split(".")[0])) + process_utils.write_list(output_file, phrases) + +def parse_args(): + """ + Actually, the only argument used right now is the formats to convert + """ + parser = argparse.ArgumentParser() + parser.add_argument('sections', type=str, nargs='*', help='Which transformations to use: {}'.format(" ".join(ARGUMENTS.keys()))) + args = parser.parse_args() + if not args.sections: + args.sections = list(ARGUMENTS.keys()) + return args + +def main(): + args = parse_args() + paths = default_paths.get_default_paths() + input_dir = os.path.join(paths["SENTIMENT_BASE"], "sentiment-treebank") + output_dir = paths["SENTIMENT_DATA_DIR"] + + os.makedirs(output_dir, exist_ok=True) + for section in args.sections: + for treebank_file in TREEBANK_FILES: + convert_version(section, 
treebank_file, input_dir, output_dir) + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/utils/datasets/sentiment/process_utils.py b/stanza/stanza/utils/datasets/sentiment/process_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6bd29be639580af22df0f95a54b7a952681b758 --- /dev/null +++ b/stanza/stanza/utils/datasets/sentiment/process_utils.py @@ -0,0 +1,147 @@ +import csv +import glob +import json +import os +import tempfile + +from collections import namedtuple + +from tqdm import tqdm + +import stanza +from stanza.models.classifiers.data import SentimentDatum + +Split = namedtuple('Split', ['filename', 'weight']) + +SHARDS = ("train", "dev", "test") + +def write_list(out_filename, dataset): + """ + Write a list of items to the given output file + + Expected: list(SentimentDatum) + """ + formatted_dataset = [line._asdict() for line in dataset] + # Rather than write the dataset at once, we write one line at a time + # Using `indent` puts each word on a separate line, which is rather noisy, + # but not formatting at all makes one long line out of an entire dataset, + # which is impossible to read + #json.dump(formatted_dataset, fout, indent=2, ensure_ascii=False) + + with open(out_filename, 'w') as fout: + fout.write("[\n") + for idx, line in enumerate(formatted_dataset): + fout.write(" ") + json.dump(line, fout, ensure_ascii=False) + if idx < len(formatted_dataset) - 1: + fout.write(",") + fout.write("\n") + fout.write("]\n") + +def write_dataset(dataset, out_directory, dataset_name): + """ + Write train, dev, test as .json files for a given dataset + + dataset: 3 lists of sentiment tuples + """ + for shard, phrases in zip(SHARDS, dataset): + output_file = os.path.join(out_directory, "%s.%s.json" % (dataset_name, shard)) + write_list(output_file, phrases) + +def write_splits(out_directory, snippets, splits): + """ + Write the given list of items to the split files in the specified output directory + """ + total_weight = 
sum(split.weight for split in splits) + divs = [] + subtotal = 0.0 + for split in splits: + divs.append(int(len(snippets) * subtotal / total_weight)) + subtotal = subtotal + split.weight + # the last div will be guaranteed to be the full thing - no math used + divs.append(len(snippets)) + + for i, split in enumerate(splits): + filename = os.path.join(out_directory, split.filename) + print("Writing {}:{} to {}".format(divs[i], divs[i+1], filename)) + write_list(filename, snippets[divs[i]:divs[i+1]]) + +def clean_tokenized_tweet(line): + line = list(line) + if len(line) > 3 and line[0] == 'RT' and line[1][0] == '@' and line[2] == ':': + line = line[3:] + elif len(line) > 4 and line[0] == 'RT' and line[1] == '@' and line[3] == ':': + line = line[4:] + elif line[0][0] == '@': + line = line[1:] + for i in range(len(line)): + if line[i][0] == '@' or line[i][0] == '#': + line[i] = line[i][1:] + line = [x for x in line if x and not x.startswith("http:") and not x.startswith("https:")] + return line + +def get_ptb_tokenized_phrases(dataset): + """ + Use the PTB tokenizer to retokenize the phrases + + Not clear which is better, "Nov." or "Nov ." 
def read_snippets(csv_filename, sentiment_column, text_column, tokenizer_language, mapping, delimiter='\t', quotechar=None, skip_first_line=False, nlp=None, encoding="utf-8"):
    """
    Read in a single CSV file and return a list of SentimentDatums

    csv_filename: path of the delimited file to read
    sentiment_column: either a single int column index, or an iterable of
      column indices.  A single column is lowercased; a multi-column label
      is kept as a tuple.  NOTE(review): the tuple path does not lowercase
      its pieces - confirm that asymmetry is intended.
    text_column: int index of the column holding the raw text
    tokenizer_language: language used to build a stanza tokenize Pipeline
      when nlp is not supplied
    mapping: dict from the raw label (string or tuple) to the final label
    delimiter, quotechar: passed through to csv.reader
    skip_first_line: if True, skip a header row before parsing
    nlp: an existing stanza Pipeline to reuse; built fresh if None
    encoding: encoding of the input file

    Raises IndexError if a sentiment column is missing from a row, and
    ValueError if a raw label is not present in mapping.
    """
    if nlp is None:
        nlp = stanza.Pipeline(tokenizer_language, processors='tokenize')

    with open(csv_filename, newline='', encoding=encoding) as fin:
        if skip_first_line:
            next(fin)
        cin = csv.reader(fin, delimiter=delimiter, quotechar=quotechar)
        lines = list(cin)

    # Read in the data and parse it
    snippets = []
    for idx, line in enumerate(tqdm(lines)):
        try:
            if isinstance(sentiment_column, int):
                sentiment = line[sentiment_column].lower()
            else:
                # multiple label columns are combined into a tuple key
                sentiment = tuple([line[x] for x in sentiment_column])
        except IndexError as e:
            raise IndexError("Columns {} did not exist at line {}: {}".format(sentiment_column, idx, line)) from e
        text = line[text_column]
        doc = nlp(text.strip())

        converted_sentiment = mapping.get(sentiment, None)
        if converted_sentiment is None:
            raise ValueError("Value {} not in mapping at line {} of {}".format(sentiment, idx, csv_filename))

        text = []
        for sentence in doc.sentences:
            # flatten the tokenized document into a single list of words
            text.extend(token.text for token in sentence.tokens)
        # strip RT / @mention / #hashtag markers and URLs
        text = clean_tokenized_tweet(text)
        snippets.append(SentimentDatum(converted_sentiment, text))
    return snippets
+""" + +import os +import sys + +from tqdm import tqdm + +import stanza +from stanza.models.classifiers.data import SentimentDatum +import stanza.utils.datasets.sentiment.process_utils as process_utils + +import stanza.utils.default_paths as default_paths + +def combine_columns(in_directory, dataset, nlp): + directory = os.path.join(in_directory, dataset) + + sentiment_file = os.path.join(directory, "sentiments.txt") + with open(sentiment_file) as fin: + sentiment = fin.readlines() + + text_file = os.path.join(directory, "sents.txt") + with open(text_file) as fin: + text = fin.readlines() + + text = [[token.text for sentence in nlp(line.strip()).sentences for token in sentence.tokens] + for line in tqdm(text)] + + phrases = [SentimentDatum(s.strip(), t) for s, t in zip(sentiment, text)] + return phrases + +def main(in_directory, out_directory, short_name): + nlp = stanza.Pipeline('vi', processors='tokenize') + for shard in ("train", "dev", "test"): + phrases = combine_columns(in_directory, shard, nlp) + output_file = os.path.join(out_directory, "%s.%s.json" % (short_name, shard)) + process_utils.write_list(output_file, phrases) + + +if __name__ == '__main__': + paths = default_paths.get_default_paths() + + if len(sys.argv) <= 1: + in_directory = os.path.join(paths['SENTIMENT_BASE'], "vietnamese", "_UIT-VSFC") + else: + in_directory = sys.argv[1] + + if len(sys.argv) <= 2: + out_directory = paths['SENTIMENT_DATA_DIR'] + else: + out_directory = sys.argv[2] + + if len(sys.argv) <= 3: + short_name = 'vi_vsfc' + else: + short_name = sys.argv[3] + + main(in_directory, out_directory, short_name) diff --git a/stanza/stanza/utils/datasets/tokenization/convert_my_alt.py b/stanza/stanza/utils/datasets/tokenization/convert_my_alt.py new file mode 100644 index 0000000000000000000000000000000000000000..a3aea8914e5085467cf11e481e04beb0ff763827 --- /dev/null +++ b/stanza/stanza/utils/datasets/tokenization/convert_my_alt.py @@ -0,0 +1,189 @@ +"""Converts the Myanmar ALT corpus to a 
def read_split(input_dir, section):
    """
    Read the URL list describing one split (train, dev, or test) of ALT.

    Format (at least for the Myanmar section of ALT) is one description
    per line, each starting with "URL.<number>".  We do not care about the
    URL itself; all we want is the number, which is used later to assign
    each tree to a split.

    Returns a set of the numbers, kept as strings.
    """
    filename = os.path.join(input_dir, "myanmar", "my_alt", "URL-%s.txt" % section)
    url_ids = set()
    with open(filename) as fin:
        for raw_line in fin:
            raw_line = raw_line.strip()
            if not raw_line or not raw_line.startswith("URL"):
                continue
            # first whitespace-separated field looks like URL.100161
            first_field = raw_line.split(maxsplit=1)[0]
            pieces = first_field.split(".")
            assert len(pieces) == 2
            assert pieces[0] == 'URL'
            url_ids.add(pieces[1])
    return url_ids
def write_sentence(fout, words, spaces):
    """
    Write one sentence of fake dependency conllu rows for the given words.

    spaces is a fraction of the words which should randomly have spaces;
    if 0.0, none of the words will have spaces.  This is because the
    Myanmar language doesn't require spaces, but spaces always separate
    words.  The head/deprel columns are fake (each word attaches to the
    previous one) since only tokenization uses this output.
    """
    fout.write("# text = %s\n" % "".join(words))

    for idx, word in enumerate(words):
        deprel = "root" if idx == 0 else "dep"
        # one random draw per word decides whether a space follows it
        space_field = "SpaceAfter=No" if random.random() > spaces else "_"
        fields = [str(idx + 1), word, word,
                  "_", "_", "_",
                  str(idx), deprel,
                  "_", space_field]
        fout.write("\t".join(fields))
        fout.write("\n")
    fout.write("\n")
def read_tokens_file(token_file):
    """
    Read a one-token-per-line file into a list of sentences.

    Blank lines separate sentences; each sentence is a list of the token
    strings.  A trailing sentence without a final blank line is kept.
    """
    sentences = []
    pending = []
    with open(token_file, encoding="utf-8") as fin:
        for raw in fin:
            token = raw.strip()
            if token:
                pending.append(token)
                continue
            # a blank line closes out the current sentence, if any
            if pending:
                sentences.append(pending)
                pending = []
    if pending:
        sentences.append(pending)
    return sentences
+ The only fields set will be the token index, the token text, and possibly SpaceAfter=No + where SpaceAfter=No is true if the next token started with no whitespace in the text file + """ + with open(text_file, encoding="utf-8") as fin: + text = fin.read() + + tokens = read_tokens_file(token_file) + tokens = [[token for sentence in tokens for token in sentence]] + tokens_doc = match_tokens_with_text(tokens, text) + + assert len(tokens_doc.sentences) == 1 + assert len(tokens_doc.sentences[0].tokens) == len(tokens[0]) + + sentences = read_sentences_file(sentence_file) + sentences_doc = match_tokens_with_text([sentences], text) + + assert len(sentences_doc.sentences) == 1 + assert len(sentences_doc.sentences[0].tokens) == len(sentences) + + start_token_idx = 0 + sentences = [] + for sent_idx, sentence in enumerate(sentences_doc.sentences[0].tokens): + tokens = [] + tokens.append("# sent_id = %d" % (base_sent_idx + sent_idx + 1)) + tokens.append("# text = %s" % text[sentence.start_char:sentence.end_char].replace("\n", " ")) + token_idx = 0 + while token_idx + start_token_idx < len(tokens_doc.sentences[0].tokens): + token = tokens_doc.sentences[0].tokens[token_idx + start_token_idx] + if token.start_char >= sentence.end_char: + # have reached the end of this sentence + # continue with the next sentence + start_token_idx += token_idx + break + + if token_idx + start_token_idx == len(tokens_doc.sentences[0].tokens) - 1: + # definitely the end of the document + space_after = True + elif token.end_char == tokens_doc.sentences[0].tokens[token_idx + start_token_idx + 1].start_char: + space_after = False + else: + space_after = True + token = [str(token_idx+1), token.text] + ["_"] * 7 + ["_" if space_after else "SpaceAfter=No"] + assert len(token) == 10, "Token length: %d" % len(token) + token = "\t".join(token) + tokens.append(token) + token_idx += 1 + sentences.append(tokens) + return sentences + +def extract_sentences(dataset_files): + sentences = [] + for text_file, 
def split_sentences(sentences, train_split=0.8, dev_split=0.1):
    """
    Assign each sentence to train/dev/test via a fixed-seed coin flip.

    Splits randomly without shuffling: the relative order of the input is
    preserved within each shard, and the fixed seed makes the assignment
    reproducible across runs.  Returns (train, dev, test).
    """
    generator = random.Random(1234)
    shards = ([], [], [])
    dev_cutoff = train_split + dev_split
    for sentence in sentences:
        draw = generator.random()
        if draw < train_split:
            which = 0
        elif draw < dev_cutoff:
            which = 1
        else:
            which = 2
        shards[which].append(sentence)
    return shards
FileNotFoundError("When looking in %s, found %s as a text file, but did not find a corresponding sentences file at %s_%s Please give an input directory which has only the text files, tokens files, and sentences files" % (input_path, filename, sentence_prefix, filename)) + text_file = os.path.join(input_path, filename) + token_file = os.path.join(input_path, token_files[filename]) + sentence_file = os.path.join(input_path, sentence_files[filename]) + dataset_files.append((text_file, token_file, sentence_file)) + return dataset_files + +SHARDS = ("train", "dev", "test") + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--token_prefix', type=str, default="tkns", help="Prefix for the token files") + parser.add_argument('--sentence_prefix', type=str, default="stns", help="Prefix for the token files") + parser.add_argument('--input_path', type=str, default="extern_data/sindhi/tokenization", help="Where to find all of the input files. Files with the prefix tkns_ will be treated as token files, files with the prefix stns_ will be treated as sentence files, and all others will be the text files.") + parser.add_argument('--output_path', type=str, default="data/tokenize", help="Where to output the results") + parser.add_argument('--dataset', type=str, default="sd_isra", help="What name to give this dataset") + args = parser.parse_args() + + dataset_files = find_dataset_files(args.input_path, args.token_prefix, args.sentence_prefix) + + tokenizer_dir = args.output_path + short_name = args.dataset # todo: convert a full name? 
+ + sentences = extract_sentences(dataset_files) + splits = split_sentences(sentences) + + os.makedirs(args.output_path, exist_ok=True) + for dataset, shard in zip(splits, SHARDS): + output_conllu = common.tokenizer_conllu_name(tokenizer_dir, short_name, shard) + common.write_sentences_to_conllu(output_conllu, dataset) + + common.convert_conllu_to_txt(tokenizer_dir, short_name) + common.prepare_tokenizer_treebank_labels(tokenizer_dir, short_name) + +if __name__ == '__main__': + main() diff --git a/stanza/stanza/utils/datasets/tokenization/convert_th_best.py b/stanza/stanza/utils/datasets/tokenization/convert_th_best.py new file mode 100644 index 0000000000000000000000000000000000000000..46a005c4f56695e216878dea878467276273c95d --- /dev/null +++ b/stanza/stanza/utils/datasets/tokenization/convert_th_best.py @@ -0,0 +1,171 @@ +"""Parses the BEST Thai dataset. + +That is to say, the dataset named BEST. We have not yet figured out +which segmentation standard we prefer. + +Note that the version of BEST we used actually had some strange +sentence splits according to a native Thai speaker. Not sure how to +fix that. Options include doing it automatically or finding some +knowledgable annotators to resplit it for us (or just not using BEST) + +This outputs the tokenization results in a conll format similar to +that of the UD treebanks, so we pretend to be a UD treebank for ease +of compatibility with the stanza tools. 
+ +BEST can be downloaded from here: + +https://aiforthai.in.th/corpus.php + +python3 -m stanza.utils.datasets.tokenization.process_best extern_data/thai/best data/tokenize +./scripts/run_tokenize.sh UD_Thai-best --dropout 0.05 --unit_dropout 0.05 --steps 50000 +""" +import glob +import os +import random +import re +import sys + +try: + from pythainlp import sent_tokenize +except ImportError: + pass + +from stanza.utils.datasets.tokenization.process_thai_tokenization import reprocess_lines, write_dataset, convert_processed_lines, write_dataset_best, write_dataset + +def clean_line(line): + line = line.replace("html>", "html|>") + # news_00089.txt + line = line.replace("", "") + line = line.replace("", "") + # specific error that occurs in encyclopedia_00095.txt + line = line.replace("Penn", "|Penn>") + # news_00058.txt + line = line.replace("จม.เปิดผนึก", "จม.|เปิดผนึก") + # news_00015.txt + line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) + # news_00024.txt + line = re.sub("([^|<>]+)", "\\1", line) + # news_00055.txt + line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) + line = re.sub("([^|<>]+)([^|<>]+)", "\\1|\\2", line) + line = re.sub("([^|<>]+)([^|<>]+) ([^|<>]+)", "\\1|\\2|\\3", line) + # news_00008.txt and other news articles + line = re.sub("([0-9])", "|\\1", line) + line = line.replace(" ", "|") + line = line.replace("", "") + line = line.replace("", "") + line = line.strip() + return line + + +def clean_word(word): + # novel_00078.txt + if word == '': + return 'พี่มน' + if word.startswith("") and word.endswith(""): + return word[4:-5] + if word.startswith("") and word.endswith(""): + return word[4:-5] + if word.startswith("") and word.endswith(""): + return word[6:-7] + """ + if word.startswith(""): + return word[4:] + if word.endswith(""): + return word[:-5] + """ + if word.startswith(""): + return word[4:] + if word.endswith(""): + return word[:-5] + if word.startswith(""): + return word[6:] + if word.endswith(""): + return word[:-7] + if word == 
'<': + return word + return word + +def read_data(input_dir): + # data for test sets + test_files = [os.path.join(input_dir, 'TEST_100K_ANS.txt')] + print(test_files) + + # data for train and dev sets + subdirs = [os.path.join(input_dir, 'article'), + os.path.join(input_dir, 'encyclopedia'), + os.path.join(input_dir, 'news'), + os.path.join(input_dir, 'novel')] + files = [] + for subdir in subdirs: + if not os.path.exists(subdir): + raise FileNotFoundError("Expected a directory that did not exist: {}".format(subdir)) + files.extend(glob.glob(os.path.join(subdir, '*.txt'))) + + test_documents = [] + for filename in test_files: + print("File name:", filename) + with open(filename) as fin: + processed_lines = [] + for line in fin.readlines(): + line = clean_line(line) + words = line.split("|") + words = [clean_word(x) for x in words] + for word in words: + if len(word) > 1 and word[0] == '<': + raise ValueError("Unexpected word '{}' in document {}".format(word, filename)) + words = [x for x in words if x] + processed_lines.append(words) + + processed_lines = reprocess_lines(processed_lines) + paragraphs = convert_processed_lines(processed_lines) + + test_documents.extend(paragraphs) + print("Test document finished.") + + documents = [] + + for filename in files: + with open(filename) as fin: + print("File:", filename) + processed_lines = [] + for line in fin.readlines(): + line = clean_line(line) + words = line.split("|") + words = [clean_word(x) for x in words] + for word in words: + if len(word) > 1 and word[0] == '<': + raise ValueError("Unexpected word '{}' in document {}".format(word, filename)) + words = [x for x in words if x] + processed_lines.append(words) + + processed_lines = reprocess_lines(processed_lines) + paragraphs = convert_processed_lines(processed_lines) + + documents.extend(paragraphs) + + print("All documents finished.") + + return documents, test_documents + + +def main(*args): + random.seed(1000) + if not args: + args = sys.argv[1:] + + 
input_dir = args[0] + full_input_dir = os.path.join(input_dir, "thai", "best") + if os.path.exists(full_input_dir): + # otherwise hopefully the user gave us the full path? + input_dir = full_input_dir + + output_dir = args[1] + documents, test_documents = read_data(input_dir) + print("Finished reading data.") + write_dataset_best(documents, test_documents, output_dir, "best") + + +if __name__ == '__main__': + main() + diff --git a/stanza/stanza/utils/datasets/tokenization/convert_th_lst20.py b/stanza/stanza/utils/datasets/tokenization/convert_th_lst20.py new file mode 100644 index 0000000000000000000000000000000000000000..744c44cdd3c16b81fce7b5e80abeef35b100414a --- /dev/null +++ b/stanza/stanza/utils/datasets/tokenization/convert_th_lst20.py @@ -0,0 +1,131 @@ +"""Processes the tokenization section of the LST20 Thai dataset + +The dataset is available here: + +https://aiforthai.in.th/corpus.php + +The data should be installed under ${EXTERN_DATA}/thai/LST20_Corpus + +python3 -m stanza.utils.datasets.tokenization.convert_th_lst20 extern_data data/tokenize + +Unlike Orchid and BEST, LST20 has train/eval/test splits, which we relabel train/dev/test. 
def read_document(lines, spaces_after, split_clauses):
    """
    Convert one LST20 file (token-per-line TSV) into our document format.

    lines: raw lines of the file.  Blank lines separate sentences; a first
      column of '_' represents a space rather than a real token.
    spaces_after: if True, mark the last token of every sentence as being
      followed by a space
    split_clauses: if True, treat a '_' row whose fourth column is 'O'
      (a space between clauses, per the --split_clauses help text) as a
      sentence break instead of an in-sentence space

    Tokens are (text, space_follows) tuples.  Returns [[document]], i.e.
    the whole file as one document containing one paragraph of sentences.
    """
    document = []
    sentence = []
    for line in lines:
        line = line.strip()
        if not line:
            # blank line: close out the current sentence, if any
            if sentence:
                if spaces_after:
                    sentence[-1] = (sentence[-1][0], True)
                document.append(sentence)
                sentence = []
        else:
            pieces = line.split("\t")
            # there are some nbsp in tokens in lst20, but the downstream tools expect spaces
            pieces = [p.replace("\xa0", " ") for p in pieces]
            if split_clauses and pieces[0] == '_' and pieces[3] == 'O':
                if sentence:
                    # note that we don't need to check spaces_after
                    # the "token" is a space anyway
                    sentence[-1] = (sentence[-1][0], True)
                    document.append(sentence)
                    sentence = []
            elif pieces[0] == '_':
                # in-sentence space: flag the previous token
                sentence[-1] = (sentence[-1][0], True)
            else:
                sentence.append((pieces[0], False))

    if sentence:
        if spaces_after:
            sentence[-1] = (sentence[-1][0], True)
        document.append(sentence)
        sentence = []
    # TODO: is there any way to divide up a single document into paragraphs?
    return [[document]]
def convert(input_dir, output_dir, args):
    """
    Convert the three LST20 sections into tokenizer datasets.

    LST20's "eval" section is written out as our "dev" shard; train and
    test keep their names.  Raises FileNotFoundError if the corpus is not
    installed under <input_dir>/thai/LST20_Corpus.
    """
    corpus_dir = os.path.join(input_dir, "thai", "LST20_Corpus")
    if not os.path.exists(corpus_dir):
        raise FileNotFoundError("Could not find LST20 corpus in {}".format(corpus_dir))

    section_map = (("train", "train"),
                   ("eval", "dev"),
                   ("test", "test"))
    for in_section, out_section in section_map:
        print("Processing %s" % out_section)
        documents = read_data(corpus_dir, in_section, args.lst20_resegment, args.lst20_spaces_after, args.split_clauses)
        print(" Read in %d documents" % len(documents))
        write_section(output_dir, "lst20", out_section, documents)