|
|
""" |
|
|
Collects a few of the conparser utility methods which don't belong elsewhere |
|
|
""" |
|
|
|
|
|
from collections import Counter |
|
|
import logging |
|
|
import warnings |
|
|
|
|
|
import torch.nn as nn |
|
|
from torch import optim |
|
|
|
|
|
from stanza.models.common.doc import TEXT, Document |
|
|
from stanza.models.common.utils import get_optimizer |
|
|
from stanza.models.constituency.base_model import SimpleModel |
|
|
from stanza.models.constituency.parse_transitions import TransitionScheme |
|
|
from stanza.models.constituency.parse_tree import Tree |
|
|
from stanza.utils.get_tqdm import get_tqdm |
|
|
|
|
|
tqdm = get_tqdm() |
|
|
|
|
|
DEFAULT_LEARNING_RATES = { "adamw": 0.0002, "adadelta": 1.0, "sgd": 0.001, "adabelief": 0.00005, "madgrad": 0.0000007, "mirror_madgrad": 0.00005 }
DEFAULT_LEARNING_EPS = { "adabelief": 1e-12, "adadelta": 1e-6, "adamw": 1e-8 }
DEFAULT_LEARNING_RHO = 0.9
DEFAULT_MOMENTUM = { "madgrad": 0.9, "mirror_madgrad": 0.9, "sgd": 0.9 }
DEFAULT_WEIGHT_DECAY = { "adamw": 0.05, "adadelta": 0.02, "sgd": 0.01, "adabelief": 1.2e-6, "madgrad": 2e-6, "mirror_madgrad": 2e-6 }

tlogger = logging.getLogger('stanza.constituency.trainer')
|
|
|
|
|
def retag_tags(doc, pipelines, xpos):
    """
    Returns a list of lists of tags, one tag per word, for the items in doc

    doc can be anything which feeds into the pipeline(s)
    pipelines is a list of 1 or more retag pipelines
    if multiple pipelines are given, majority vote wins
    """
    tag_lists = []
    for pipeline in pipelines:
        doc = pipeline(doc)
        tag_lists.append([[x.xpos if xpos else x.upos for x in sentence.words] for sentence in doc.sentences])

    # majority vote: for each word, gather the candidate tags from each
    # pipeline and keep the most common one
    tag_lists = [[Counter(tag).most_common(1)[0][0] for tag in zip(*sentence)]
                 for sentence in zip(*tag_lists)]
    return tag_lists
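
# A small illustration of the vote (hypothetical tags from three pipelines):
#   tag_lists = [[["DT", "NN"]], [["DT", "NN"]], [["DT", "VB"]]]
#   zip(*tag_lists) regroups by sentence and zip(*sentence) regroups by word,
#   so the second word sees Counter(("NN", "NN", "VB")), and "NN" wins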
|
|
|
|
|
def retag_trees(trees, pipelines, xpos=True):
    """
    Retag all of the trees using the given pipeline(s)

    Returns a list of new trees
    """
    if len(trees) == 0:
        return trees

    new_trees = []
    chunk_size = 1000
    with tqdm(total=len(trees)) as pbar:
        for chunk_start in range(0, len(trees), chunk_size):
            chunk_end = min(chunk_start + chunk_size, len(trees))
            chunk = trees[chunk_start:chunk_end]
            sentences = []
            try:
                for idx, tree in enumerate(chunk):
                    tokens = [{TEXT: pt.children[0].label} for pt in tree.yield_preterminals()]
                    sentences.append(tokens)
            except ValueError as e:
                raise ValueError("Unable to process tree %d" % (idx + chunk_start)) from e

            doc = Document(sentences)
            tag_lists = retag_tags(doc, pipelines, xpos)

            for tree_idx, (tree, tags) in enumerate(zip(chunk, tag_lists)):
                try:
                    if any(tag is None for tag in tags):
                        raise RuntimeError("Tagged tree #{} with a None tag!\n{}\n{}".format(tree_idx, tree, tags))
                    new_tree = tree.replace_tags(tags)
                    new_trees.append(new_tree)
                    pbar.update(1)
                except ValueError as e:
                    raise ValueError("Failed to properly retag tree #{}: {}".format(tree_idx, tree)) from e
    if len(new_trees) != len(trees):
        raise AssertionError("Retagged tree counts did not match: {} vs {}".format(len(new_trees), len(trees)))
    return new_trees
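
# A minimal usage sketch, assuming a POS retagging pipeline has been built
# elsewhere (the Pipeline settings and the tree literal here are illustrative,
# not the trainer's exact configuration):
#   from stanza import Pipeline
#   from stanza.models.constituency import tree_reader
#   pipeline = Pipeline("en", processors="tokenize,pos", tokenize_pretokenized=True)
#   trees = tree_reader.read_trees("(ROOT (S (NP (DT the) (NN dog)) (VP (VBZ barks))))")
#   trees = retag_trees(trees, [pipeline], xpos=True)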
|
|
NONLINEARITY = {
    'celu': nn.CELU,
    'elu': nn.ELU,
    'gelu': nn.GELU,
    'hardshrink': nn.Hardshrink,
    'hardtanh': nn.Hardtanh,
    'leaky_relu': nn.LeakyReLU,
    'logsigmoid': nn.LogSigmoid,
    'prelu': nn.PReLU,
    'relu': nn.ReLU,
    'relu6': nn.ReLU6,
    'rrelu': nn.RReLU,
    'selu': nn.SELU,
    'softplus': nn.Softplus,
    'softshrink': nn.Softshrink,
    'softsign': nn.Softsign,
    'tanhshrink': nn.Tanhshrink,
    'tanh': nn.Tanh,
}

# these nonlinearities only exist in more recent versions of torch,
# so we add them to the map only if this version of torch has them
nonlinearity_list = [
    'GLU',
    'Hardsigmoid',
    'Hardswish',
    'Mish',
    'SiLU',
]

for nonlinearity in nonlinearity_list:
    if hasattr(nn, nonlinearity):
        NONLINEARITY[nonlinearity.lower()] = getattr(nn, nonlinearity)
|
|
|
|
|
def build_nonlinearity(nonlinearity):
    """
    Look up "nonlinearity" in a map from function name to function, build the appropriate layer.
    """
    if nonlinearity in NONLINEARITY:
        return NONLINEARITY[nonlinearity]()
    raise ValueError('Chosen value of nonlinearity, "%s", not handled' % nonlinearity)
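
# e.g. build_nonlinearity('gelu') returns a fresh nn.GELU module, while an
# unknown name such as 'swish' raises the ValueError above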
|
|
|
|
|
def build_optimizer(args, model, build_simple_adadelta=False):
    """
    Build an optimizer based on the arguments given

    If we are "multistage" training and epochs_trained < epochs // 2,
    we build an AdaDelta optimizer instead of whatever was requested.
    The build_simple_adadelta parameter controls this.
    """
    bert_learning_rate = 0.0
    bert_weight_decay = args['bert_weight_decay']
    if build_simple_adadelta:
        optim_type = 'adadelta'
        bert_finetune = args.get('stage1_bert_finetune', False)
        if bert_finetune:
            bert_learning_rate = args['stage1_bert_learning_rate']
        learning_beta2 = 0.999
        learning_eps = DEFAULT_LEARNING_EPS['adadelta']
        learning_rate = args['stage1_learning_rate']
        learning_rho = DEFAULT_LEARNING_RHO
        momentum = None
        weight_decay = DEFAULT_WEIGHT_DECAY['adadelta']
    else:
        optim_type = args['optim'].lower()
        bert_finetune = args.get('bert_finetune', False)
        if bert_finetune:
            bert_learning_rate = args['bert_learning_rate']
        learning_beta2 = args['learning_beta2']
        learning_eps = args['learning_eps']
        learning_rate = args['learning_rate']
        learning_rho = args['learning_rho']
        momentum = args['learning_momentum']
        weight_decay = args['learning_weight_decay']

    return get_optimizer(name=optim_type,
                         model=model,
                         lr=learning_rate,
                         betas=(0.9, learning_beta2),
                         eps=learning_eps,
                         momentum=momentum,
                         weight_decay=weight_decay,
                         bert_learning_rate=bert_learning_rate,
                         bert_weight_decay=weight_decay*bert_weight_decay,
                         is_peft=args.get('use_peft', False),
                         bert_finetune_layers=args['bert_finetune_layers'],
                         opt_logger=tlogger)
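
# Sketch of how this is called (the args values here are hypothetical - the
# real ones come from the conparser's argument parser):
#   args = {'optim': 'adamw', 'learning_rate': DEFAULT_LEARNING_RATES['adamw'],
#           'learning_beta2': 0.999, 'learning_eps': DEFAULT_LEARNING_EPS['adamw'],
#           'learning_rho': DEFAULT_LEARNING_RHO, 'learning_momentum': None,
#           'learning_weight_decay': DEFAULT_WEIGHT_DECAY['adamw'],
#           'bert_weight_decay': 1.0, 'bert_finetune_layers': None}
#   optimizer = build_optimizer(args, model)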
|
|
|
|
|
def build_scheduler(args, optimizer, first_optimizer=False):
    """
    Build the scheduler for the conparser based on its args

    We used to use a warmup for the learning rate, but that wasn't working very well.
    Now we just use ReduceLROnPlateau, which does quite well.
    """
    if first_optimizer:
        min_lr = args['stage1_learning_rate_min_lr']
    else:
        min_lr = args['learning_rate_min_lr']
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='max',
                                                     factor=args['learning_rate_factor'],
                                                     patience=args['learning_rate_patience'],
                                                     cooldown=args['learning_rate_cooldown'],
                                                     min_lr=min_lr)
    return scheduler
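
# mode='max' means the scheduler expects a score to maximize; the training
# loop should pass its dev metric after each evaluation, e.g.:
#   scheduler.step(dev_score)
# (dev_score is illustrative - whichever metric the trainer tracks)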
|
|
|
|
|
def initialize_linear(linear, nonlinearity, bias):
    """
    Initializes the bias to a positive value, hopefully preventing dead neurons
    """
    if nonlinearity in ('relu', 'leaky_relu'):
        nn.init.kaiming_normal_(linear.weight, nonlinearity=nonlinearity)
        nn.init.uniform_(linear.bias, 0, 1 / (bias * 2) ** 0.5)
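
# A minimal sketch (hidden_size is hypothetical; note that the third argument
# is the size used to scale the bias range, not the bias itself):
#   linear = nn.Linear(hidden_size, hidden_size)
#   initialize_linear(linear, 'relu', hidden_size)
# kaiming_normal_ sets the weights for the relu family, and the small uniform
# positive bias makes it less likely that a relu unit starts out dead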
|
|
|
|
|
def add_predict_output_args(parser):
    """
    Args specifically for the output location of data
    """
    parser.add_argument('--predict_dir', type=str, default=".", help='Where to write the predictions during --mode predict. Pred and orig files will be written - the orig file will be retagged if that is requested. Writing the orig file is useful for removing None and retagging')
    parser.add_argument('--predict_file', type=str, default=None, help='Base name for writing predictions')
    parser.add_argument('--predict_format', type=str, default="{:_O}", help='Format to use when writing predictions')
    parser.add_argument('--predict_output_gold_tags', default=False, action='store_true', help='Output gold tags as part of the evaluation - useful for putting the trees through EvalB')

def postprocess_predict_output_args(args):
    """
    Wrap a bare format spec in {:...} so it can be used as a format string
    """
    if len(args['predict_format']) <= 2 or (len(args['predict_format']) <= 3 and args['predict_format'].endswith("Vi")):
        args['predict_format'] = "{:" + args['predict_format'] + "}"
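
# e.g. a bare "_O" given on the command line becomes "{:_O}", matching the
# default value of --predict_format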
|
|
|
|
|
|
|
|
def get_open_nodes(trees, transition_scheme):
    """
    Return a list of all open nodes in the given dataset.
    Depending on the parameters, these may be single or compound open transitions.
    """
    if transition_scheme is TransitionScheme.TOP_DOWN_COMPOUND:
        return Tree.get_compound_constituents(trees)
    elif transition_scheme is TransitionScheme.IN_ORDER_COMPOUND:
        return Tree.get_compound_constituents(trees, separate_root=True)
    else:
        return [(x,) for x in Tree.get_unique_constituent_labels(trees)]
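
# In each case the result is a list of label tuples: single-label tuples such
# as ('NP',) for the non-compound schemes, while the compound schemes can also
# return multi-label tuples such as ('S', 'VP') for unary chains opened in one
# transition (the exact tuples depend on the treebank)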
|
|
|
|
|
|
|
|
def verify_transitions(trees, sequences, transition_scheme, unary_limit, reverse, name, root_labels):
    """
    Given a list of trees and their transition sequences, verify that the sequences rebuild the trees
    """
    model = SimpleModel(transition_scheme, unary_limit, reverse, root_labels)
    tlogger.info("Verifying the transition sequences for %d trees", len(trees))

    data = zip(trees, sequences)
    if tlogger.getEffectiveLevel() <= logging.INFO:
        data = tqdm(zip(trees, sequences), total=len(trees))

    for tree_idx, (tree, sequence) in enumerate(data):
        state = model.initial_state_from_gold_trees([tree])[0]
        for idx, trans in enumerate(sequence):
            if not trans.is_legal(state, model):
                raise RuntimeError("Tree {} of {} failed: transition {}:{} was not legal in a transition sequence:\nOriginal tree: {}\nTransitions: {}".format(tree_idx, name, idx, trans, tree, sequence))
            state = trans.apply(state, model)
        result = model.get_top_constituent(state.constituents)
        if reverse:
            result = result.reverse()
        if tree != result:
            raise RuntimeError("Tree {} of {} failed: transition sequence did not match for a tree!\nOriginal tree:{}\nTransitions: {}\nResult tree:{}".format(tree_idx, name, tree, sequence, result))
|
|
|
|
|
def check_constituents(train_constituents, trees, treebank_name, fail=True):
    """
    Check that all the constituents in the other dataset are known in the train set
    """
    constituents = Tree.get_unique_constituent_labels(trees)
    for con in constituents:
        if con not in train_constituents:
            first_error = None
            num_errors = 0
            for tree_idx, tree in enumerate(trees):
                tree_constituents = Tree.get_unique_constituent_labels(tree)
                if con in tree_constituents:
                    num_errors += 1
                    if first_error is None:
                        first_error = tree_idx
            error = "Found constituent label {} in the {} set which doesn't exist in the train set. This constituent label occurred in {} trees, with the first tree at index {}, counting from 1.\nThe error tree (which may have POS tags changed from the retagger and may be missing functional tags or empty nodes) is:\n{:P}".format(con, treebank_name, num_errors, (first_error+1), trees[first_error])
            if fail:
                raise RuntimeError(error)
            else:
                warnings.warn(error)
|
|
|
|
|
def check_root_labels(root_labels, other_trees, treebank_name):
    """
    Check that all the root states in the other dataset are known in the train set
    """
    for root_state in Tree.get_root_labels(other_trees):
        if root_state not in root_labels:
            raise RuntimeError("Found root state {} in the {} set which is not a ROOT state in the train set".format(root_state, treebank_name))
|
|
|
|
|
def remove_duplicate_trees(trees, treebank_name):
    """
    Filter duplicates from the given dataset
    """
    new_trees = []
    known_trees = set()
    for tree in trees:
        tree_str = "{}".format(tree)
        if tree_str in known_trees:
            continue
        known_trees.add(tree_str)
        new_trees.append(tree)
    if len(new_trees) < len(trees):
        tlogger.info("Filtered %d duplicates from %s dataset", (len(trees) - len(new_trees)), treebank_name)
    return new_trees
|
|
|
|
|
def remove_singleton_trees(trees):
    """
    Remove trees which are just a root and a single word

    TODO: remove these trees in the conversion instead of here
    """
|
|
    new_trees = [x for x in trees if
                 len(x.children) > 1 or
                 (len(x.children) == 1 and len(x.children[0].children) > 1) or
                 (len(x.children) == 1 and len(x.children[0].children) == 1 and len(x.children[0].children[0].children) >= 1)]
    if len(trees) - len(new_trees) > 0:
        tlogger.info("Eliminated %d trees with missing structure", (len(trees) - len(new_trees)))
    return new_trees
|
|
|
|
|
|