#!/usr/bin/env python3 # # This file is part of LatinPipe EvaLatin 24 # . # # Copyright 2024 Institute of Formal and Applied Linguistics, Faculty of # Mathematics and Physics, Charles University in Prague, Czech Republic. # # This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. import argparse import collections import datetime import difflib import io import json import os import pickle import re from typing import Self os.environ.setdefault("KERAS_BACKEND", "torch") import keras import numpy as np import torch import transformers import ufal.chu_liu_edmonds import latinpipe_evalatin24_eval parser = argparse.ArgumentParser() parser.add_argument("--batch_size", default=32, type=int, help="Batch size.") parser.add_argument("--deprel", default="full", choices=["full", "universal"], type=str, help="Deprel kind.") parser.add_argument("--dev", default=[], nargs="+", type=str, help="Dev CoNLL-U files.") parser.add_argument("--dropout", default=0.5, type=float, help="Dropout") parser.add_argument("--embed_tags", default="", type=str, help="Tags to embed on input.") parser.add_argument("--epochs", default=30, type=int, help="Number of epochs.") parser.add_argument("--epochs_frozen", default=0, type=int, help="Number of epochs with frozen transformer.") parser.add_argument("--exp", default=None, type=str, help="Experiment name.") parser.add_argument("--label_smoothing", default=0.03, type=float, help="Label smoothing.") parser.add_argument("--learning_rate", default=2e-5, type=float, help="Learning rate.") parser.add_argument("--learning_rate_decay", default="cos", choices=["none", "cos"], type=str, help="Learning rate decay.") parser.add_argument("--learning_rate_warmup", default=2_000, type=int, help="Number of warmup steps.") parser.add_argument("--load", default=[], type=str, nargs="*", help="Path to load models from.") parser.add_argument("--max_train_sentence_len", default=150, type=int, help="Max sentence subwords in training.") parser.add_argument("--optimizer", default="adam", choices=["adam", "adafactor"], type=str, help="Optimizer.") parser.add_argument("--parse", default=1, type=int, help="Parse.") parser.add_argument("--parse_attention_dim", default=512, type=int, help="Parse attention dimension.") parser.add_argument("--rnn_dim", default=512, type=int, help="RNN layers size.") parser.add_argument("--rnn_layers", default=2, type=int, help="RNN layers.") parser.add_argument("--rnn_type", default="LSTMTorch", choices=["LSTM", "GRU", "LSTMTorch", "GRUTorch"], help="RNN type.") parser.add_argument("--save_checkpoint", default=False, action="store_true", help="Save checkpoint.") parser.add_argument("--seed", default=42, type=int, help="Initial random seed.") parser.add_argument("--steps_per_epoch", default=1_000, type=int, help="Steps per epoch.") parser.add_argument("--single_root", default=1, type=int, help="Single root allowed only.") parser.add_argument("--subword_combination", default="first", choices=["first", "last", "sum", "concat"], type=str, help="Subword combination.") parser.add_argument("--tags", default="UPOS,LEMMAS,FEATS", type=str, help="Tags to predict.") parser.add_argument("--task_hidden_layer", default=2_048, type=int, help="Task hidden layer size.") parser.add_argument("--test", default=[], nargs="+", type=str, help="Test CoNLL-U files.") parser.add_argument("--train", default=[], nargs="+", type=str, help="Train CoNLL-U files.") parser.add_argument("--train_sampling_exponent", default=0.5, type=float, help="Train sampling exponent.") parser.add_argument("--transformers", nargs="+", type=str, help="Transformers models to use.") parser.add_argument("--treebank_ids", default=False, action="store_true", help="Include treebank IDs on input.") parser.add_argument("--threads", default=4, type=int, help="Maximum number of threads to use.") parser.add_argument("--verbose", default=2, type=int, help="Verbosity") parser.add_argument("--wandb", default=False, action="store_true", help="Log in WandB.") parser.add_argument("--word_masking", default=None, type=float, help="Word masking") class UDDataset: FORMS, LEMMAS, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC, FACTORS = range(10) FACTORS_MAP = {"FORMS": FORMS, "LEMMAS": LEMMAS, "UPOS": UPOS, "XPOS": XPOS, "FEATS": FEATS, "HEAD": HEAD, "DEPREL": DEPREL, "DEPS": DEPS, "MISC": MISC} RE_EXTRAS = re.compile(r"^#|^\d+-|^\d+\.") class Factor: def __init__(self, train_factor: Self = None): self.words_map = train_factor.words_map if train_factor else {"": 0} self.words = train_factor.words if train_factor else [""] self.word_ids = [] self.strings = [] def __init__(self, path: str, args: argparse.Namespace, treebank_id: int|None = None, train_dataset: Self = None, text: str|None = None): self.path = path # Create factors and other variables self.factors = [] for f in range(self.FACTORS): self.factors.append(self.Factor(train_dataset.factors[f] if train_dataset is not None else None)) self._extras = [] lemma_transforms = collections.Counter() # Load the CoNLL-U file with open(path, "r", encoding="utf-8") if text is None else io.StringIO(text) as file: in_sentence = False for line in file: line = line.rstrip("\r\n") if line: if self.RE_EXTRAS.match(line): if in_sentence: while len(self._extras) < len(self.factors[0].strings): self._extras.append([]) while len(self._extras[-1]) <= len(self.factors[0].strings[-1]): self._extras[-1].append("") else: while len(self._extras) <= len(self.factors[0].strings): self._extras.append([]) if not len(self._extras[-1]): self._extras[-1].append("") self._extras[-1][-1] += ("\n" if self._extras[-1][-1] else "") + line continue columns = line.split("\t")[1:] for f in range(self.FACTORS): factor = self.factors[f] if not in_sentence: factor.word_ids.append([]) factor.strings.append([]) word = columns[f] factor.strings[-1].append(word) # Add word to word_ids if f == self.FORMS: # For formw, we do not remap strings into IDs because the tokenizer will create the subwords IDs for us. factor.word_ids[-1].append(0) elif f == self.HEAD: factor.word_ids[-1].append(int(word) if word != "_" else -1) elif f == self.LEMMAS: factor.word_ids[-1].append(0) lemma_transforms[(columns[self.FORMS], word)] += 1 else: if f == self.DEPREL and args.deprel == "universal": word = word.split(":")[0] if word not in factor.words_map: if train_dataset is not None: word = "" else: factor.words_map[word] = len(factor.words) factor.words.append(word) factor.word_ids[-1].append(factor.words_map[word]) in_sentence = True else: in_sentence = False for factor in self.factors: if len(factor.word_ids): factor.word_ids[-1] = np.array(factor.word_ids[-1], np.int32) # Also load the file for evaluation if it is not a training dataset if train_dataset is not None: file.seek(0, io.SEEK_SET) self.conllu_for_eval = latinpipe_evalatin24_eval.load_conllu(file) # Construct lemma rules self.finalize_lemma_rules(lemma_transforms, create_rules=train_dataset is None) # The dataset consists of a single treebank self.treebank_ranges = [(0, len(self))] self.treebank_ids = [treebank_id] # Create an empty tokenize cache self._tokenizer_cache = {} def __len__(self): return len(self.factors[0].strings) def save_mappings(self, path: str) -> None: mappings = UDDataset.__new__(UDDataset) mappings.factors = [] for factor in self.factors: mappings.factors.append(UDDataset.Factor.__new__(UDDataset.Factor)) mappings.factors[-1].words = factor.words with open(path, "wb") as mappings_file: pickle.dump(mappings, mappings_file, protocol=4) @staticmethod def from_mappings(path: str) -> Self: with open(path, "rb") as mappings_file: mappings = pickle.load(mappings_file) for factor in mappings.factors: factor.words_map = {word: i for i, word in enumerate(factor.words)} return mappings @staticmethod def create_lemma_rule(form: str, lemma: str) -> str: diff = difflib.SequenceMatcher(None, form.lower(), lemma.lower(), False) rule, in_prefix = [], True for tag, i1, i2, j1, j2 in diff.get_opcodes(): if i2 > len(form) // 3 and in_prefix: in_prefix = False if tag == "equal": mode, jd = "L" if lemma[j2 - 1].islower() else "U", j2 - 1 while jd > j1 and lemma[jd - 1].islower() == lemma[j2 - 1].islower(): jd -= 1 rule.extend(["l" if lemma[j].islower() else "u" for j in range(j1, jd)]) rule.extend(mode * (len(form) - i2 + 1)) if tag in ["replace", "delete"]: rule.extend("D" * (len(form) - i2 + 1)) if tag in ["replace", "insert"]: rule.extend("i" + lemma[j] for j in range(j1, j2)) else: if tag == "equal": rule.extend(["l" if lemma[j].islower() else "u" for j in range(j1, j2)]) if tag in ["replace", "delete"]: rule.extend("d" * (i2 - i1)) if tag in ["replace", "insert"]: rule.extend("i" + lemma[j] for j in range(j1, j2)) return "".join(rule) @staticmethod def apply_lemma_rule(rule: str, form: str) -> str: def error(): # print("Error: cannot decode lemma rule '{}' with form '{}', copying input.".format(rule, form)) return form if rule == "": return form lemma, r, i = [], 0, 0 while r < len(rule): if rule[r] == "i": if r + 1 == len(rule): return error() r += 1 lemma.append(rule[r]) elif rule[r] == "d": i += 1 elif rule[r] in ("l", "u"): if i == len(form): return error() lemma.append(form[i].lower() if rule[r] == "l" else form[i].upper()) i += 1 elif rule[r] in ("L", "U", "D"): i2 = len(form) while r + 1 < len(rule) and rule[r + 1] == rule[r]: r += 1 i2 -= 1 if i2 < i: return error() if rule[r] == "L": lemma.extend(form[i:i2].lower()) if rule[r] == "U": lemma.extend(form[i:i2].upper()) i = i2 else: return error() r += 1 if i != len(form) or not lemma: return error() return "".join(lemma) def finalize_lemma_rules(self, lemma_transforms: collections.Counter, create_rules: bool) -> None: forms, lemmas = self.factors[self.FORMS], self.factors[self.LEMMAS] # Generate all rules rules_merged, rules_all = collections.Counter(), {} for form, lemma in lemma_transforms: rule = self.create_lemma_rule(form, lemma) rules_all[(form, lemma)] = rule if create_rules: rules_merged[rule] += 1 # Keep the rules that are used more than once if create_rules: for rule, count in rules_merged.items(): if count > 1: lemmas.words_map[rule] = len(lemmas.words) lemmas.words.append(rule) # Store the rules in the dataset for i in range(len(forms.strings)): for j in range(len(forms.strings[i])): rule = rules_all.get((forms.strings[i][j], lemmas.strings[i][j])) lemmas.word_ids[i][j] = lemmas.words_map.get(rule, 0) def tokenize(self, tokenizer: transformers.PreTrainedTokenizer) -> tuple[list[np.ndarray], list[np.ndarray]]: if tokenizer not in self._tokenizer_cache: assert tokenizer.cls_token_id is not None, "The tokenizer must have a CLS token" tokenized = tokenizer(self.factors[0].strings, add_special_tokens=True, is_split_into_words=True) tokens, word_indices = [], [] for i, sentence in enumerate(tokenized.input_ids): offset = 0 if not len(sentence) or sentence[0] != tokenizer.cls_token_id: # Handle tokenizers that do not add CLS tokens, which we need for prediction # of the root nodes during parsing. For such tokenizers, we added the CLS token # manually already, but the build_inputs_with_special_tokens() might not have added it. sentence = [tokenizer.cls_token_id] + sentence offset = 1 treebank_id = None for id_, (start, end) in zip(self.treebank_ids, self.treebank_ranges): if start <= i < end: treebank_id = id_ if treebank_id is not None: sentence.insert(1, tokenizer.additional_special_tokens_ids[treebank_id]) offset += 1 tokens.append(np.array(sentence, dtype=np.int32)) word_indices.append([(0, 0)]) for j in range(len(self.factors[0].strings[i])): span = tokenized.word_to_tokens(i, j) word_indices[-1].append((offset + span.start, offset + span.end - 1)) word_indices[-1] = np.array(word_indices[-1], dtype=np.int32) self._tokenizer_cache[tokenizer] = (tokens, word_indices) return self._tokenizer_cache[tokenizer] def write_sentence(self, output: io.TextIOBase, index: int, overrides: list = None) -> None: assert index < len(self.factors[0].strings), "Sentence index out of range" for i in range(len(self.factors[0].strings[index]) + 1): # Start by writing extras if index < len(self._extras) and i < len(self._extras[index]) and self._extras[index][i]: print(self._extras[index][i], file=output) if i == len(self.factors[0].strings[index]): break fields = [] fields.append(str(i + 1)) for f in range(self.FACTORS): factor = self.factors[f] field = factor.strings[index][i] # Overrides if overrides is not None and f < len(overrides) and overrides[f] is not None: override = overrides[f][i] if f == self.HEAD: field = str(override) if override >= 0 else "_" else: field = factor.words[override] if f == self.LEMMAS: field = self.apply_lemma_rule(field, self.factors[self.FORMS].strings[index][i]) fields.append(field) print("\t".join(fields), file=output) print(file=output) class UDDatasetMerged(UDDataset): def __init__(self, datasets: list[UDDataset]): # Create factors and other variables self.factors = [] for f in range(self.FACTORS): self.factors.append(self.Factor(None)) lemma_transforms = collections.Counter() self.treebank_ranges, self.treebank_ids = [], [] for dataset in datasets: assert len(dataset.treebank_ranges) == len(dataset.treebank_ids) == 1 self.treebank_ranges.append((len(self), len(self) + len(dataset))) self.treebank_ids.append(dataset.treebank_ids[0]) for s in range(len(dataset)): for f in range(self.FACTORS): factor = self.factors[f] factor.strings.append(dataset.factors[f].strings[s]) factor.word_ids.append([]) for i, word in enumerate(dataset.factors[f].strings[s]): if f == self.FORMS: # We do not remap strings into IDs because the tokenizer will create the subwords IDs for us. factor.word_ids[-1].append(0) if f == self.HEAD: factor.word_ids[-1].append(word) elif f == self.LEMMAS: factor.word_ids[-1].append(0) lemma_transforms[(dataset.factors[self.FORMS].strings[s][i], word)] += 1 else: if word not in factor.words_map: factor.words_map[word] = len(factor.words) factor.words.append(word) factor.word_ids[-1].append(factor.words_map[word]) self.factors[f].word_ids[-1] = np.array(self.factors[f].word_ids[-1], np.int32) # Construct lemma rules self.finalize_lemma_rules(lemma_transforms, create_rules=True) # Create an empty tokenize cache self._tokenizer_cache = {} class TorchUDDataset(torch.utils.data.Dataset): def __init__(self, ud_dataset: UDDataset, tokenizers: list[transformers.PreTrainedTokenizer], args: argparse.Namespace, training: bool): self.ud_dataset = ud_dataset self.training = training self._outputs_to_input = [args.tags.index(tag) for tag in args.embed_tags] self._inputs = [ud_dataset.tokenize(tokenizer) for tokenizer in tokenizers] self._outputs = [ud_dataset.factors[tag].word_ids for tag in args.tags] if args.parse: self._outputs.append(ud_dataset.factors[ud_dataset.HEAD].word_ids) self._outputs.append(ud_dataset.factors[ud_dataset.DEPREL].word_ids) # Trim the sentences if needed if training and args.max_train_sentence_len: trimmed_sentences = 0 for index in range(len(self)): # Over sentences max_words, need_trimming = None, False for tokens, word_indices in self._inputs: # Over transformers if max_words is None: max_words = len(word_indices[index]) while word_indices[index][max_words - 1, 1] >= args.max_train_sentence_len: max_words -= 1 need_trimming = True assert max_words >= 2, "Sentence too short after trimming" if need_trimming: for tokens, word_indices in self._inputs: # Over transformers tokens[index] = tokens[index][:word_indices[index][max_words - 1, 1] + 1] word_indices[index] = word_indices[index][:max_words] for output in self._outputs: output[index] = output[index][:max_words - 1] # No CLS tokens in outputs if args.parse: self._outputs[-2][index] = np.array([head if head < max_words else -1 for head in self._outputs[-2][index]], np.int32) trimmed_sentences += 1 if trimmed_sentences: print("Trimmed {} out of {} sentences".format(trimmed_sentences, len(self))) def __len__(self): return len(self.ud_dataset) def __getitem__(self, index: int): inputs = [] for tokens, word_indices in self._inputs: inputs.append(torch.from_numpy(tokens[index])) inputs.append(torch.from_numpy(word_indices[index])) for i in self._outputs_to_input: inputs.append(torch.from_numpy(self._outputs[i][index])) outputs = [] for output in self._outputs: outputs.append(torch.from_numpy(output[index])) return inputs, outputs class TorchUDDataLoader(torch.utils.data.DataLoader): class MergedDatasetSampler(torch.utils.data.Sampler): def __init__(self, ud_dataset: UDDataset, args: argparse.Namespace): self._treebank_ranges = ud_dataset.treebank_ranges self._sentences_per_epoch = args.steps_per_epoch * args.batch_size self._generator = torch.Generator().manual_seed(args.seed) treebank_weights = np.array([r[1] - r[0] for r in self._treebank_ranges], np.float32) treebank_weights = treebank_weights ** args.train_sampling_exponent treebank_weights /= np.sum(treebank_weights) self._treebank_sizes = np.array(treebank_weights * self._sentences_per_epoch, np.int32) self._treebank_sizes[:self._sentences_per_epoch - np.sum(self._treebank_sizes)] += 1 self._treebank_indices = [[] for _ in self._treebank_ranges] def __len__(self): return self._sentences_per_epoch def __iter__(self): indices = [] for i in range(len(self._treebank_ranges)): required = self._treebank_sizes[i] while required: if not len(self._treebank_indices[i]): self._treebank_indices[i] = self._treebank_ranges[i][0] + torch.randperm( self._treebank_ranges[i][1] - self._treebank_ranges[i][0], generator=self._generator) indices.append(self._treebank_indices[i][:required]) required -= min(len(self._treebank_indices[i]), required) indices = torch.concatenate(indices, axis=0) return iter(indices[torch.randperm(len(indices), generator=self._generator)]) def _collate_fn(self, batch): inputs, outputs = zip(*batch) batch_inputs = [] for sequences in zip(*inputs): batch_inputs.append(torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=-1)) batch_outputs = [] for output in zip(*outputs): batch_outputs.append(torch.nn.utils.rnn.pad_sequence(output, batch_first=True, padding_value=-1)) batch_weights = [batch_output != -1 for batch_output in batch_outputs] return tuple(batch_inputs), tuple(batch_outputs), tuple(batch_weights) def __init__(self, dataset: TorchUDDataset, args: argparse.Namespace, **kwargs): sampler = None if dataset.training: if len(dataset.ud_dataset.treebank_ranges) == 1: sampler = torch.utils.data.RandomSampler(dataset, generator=torch.Generator().manual_seed(args.seed)) else: assert args.steps_per_epoch is not None, "Steps per epoch must be specified when training on multiple treebanks" sampler = self.MergedDatasetSampler(dataset.ud_dataset, args) super().__init__(dataset, batch_size=args.batch_size, sampler=sampler, collate_fn=self._collate_fn, **kwargs) class LatinPipeModel(keras.Model): class HFTransformerLayer(keras.layers.Layer): def __init__(self, transformer: transformers.PreTrainedModel, subword_combination: str, word_masking: float = None, mask_token_id: int = None, **kwargs): super().__init__(**kwargs) self._transformer = transformer self._subword_combination = subword_combination self._word_masking = word_masking self._mask_token_id = mask_token_id def call(self, inputs, word_indices, training=None): if training and self._word_masking: mask = keras.ops.cast(keras.random.uniform(keras.ops.shape(inputs), dtype="float32") < self._word_masking, inputs.dtype) inputs = (1 - mask) * inputs + mask * self._mask_token_id if (training or False) != self._transformer.training: self._transformer.train(training or False) if self._subword_combination != "last": first_subwords = keras.ops.take_along_axis( self._transformer(keras.ops.maximum(inputs, 0), attention_mask=inputs > -1).last_hidden_state, keras.ops.expand_dims(keras.ops.maximum(word_indices[..., 0], 0), axis=-1), axis=1, ) if self._subword_combination != "first": last_subwords = keras.ops.take_along_axis( self._transformer(keras.ops.maximum(inputs, 0), attention_mask=inputs > -1).last_hidden_state, keras.ops.expand_dims(keras.ops.maximum(word_indices[..., 1], 0), axis=-1), axis=1, ) if self._subword_combination == "first": return first_subwords elif self._subword_combination == "last": return last_subwords elif self._subword_combination == "sum": return first_subwords + last_subwords elif self._subword_combination == "concat": return keras.ops.concatenate([first_subwords, last_subwords], axis=-1) else: raise ValueError("Unknown subword combination '{}'".format(self._subword_combination)) class LSTMTorch(keras.layers.Layer): def __init__(self, units: int, **kwargs): super().__init__(**kwargs) self._units = units def build(self, input_shape): self._lstm = torch.nn.LSTM(input_shape[-1], self._units, batch_first=True, bidirectional=True) def call(self, inputs, lengths): packed_result, _ = self._lstm.module(torch.nn.utils.rnn.pack_padded_sequence(inputs, lengths.cpu(), batch_first=True, enforce_sorted=False)) unpacked_result = torch.nn.utils.rnn.unpack_sequence(packed_result) return torch.nn.utils.rnn.pad_sequence(unpacked_result, batch_first=True, padding_value=0) class GRUTorch(keras.layers.Layer): def __init__(self, units: int, **kwargs): super().__init__(**kwargs) self._units = units def build(self, input_shape): self._gru = torch.nn.GRU(input_shape[-1], self._units, batch_first=True, bidirectional=True) def call(self, inputs, lengths): packed_result, _ = self._gru(torch.nn.utils.rnn.pack_padded_sequence(inputs, lengths.cpu(), batch_first=True, enforce_sorted=False)) unpacked_result = torch.nn.utils.rnn.unpack_sequence(packed_result) return torch.nn.utils.rnn.pad_sequence(unpacked_result, batch_first=True, padding_value=0) class ParsingHead(keras.layers.Layer): def __init__(self, num_deprels: int, task_hidden_layer: int, parse_attention_dim: int, dropout: float, **kwargs): super().__init__(**kwargs) self._head_queries_hidden = keras.layers.Dense(task_hidden_layer, activation="relu") self._head_queries_output = keras.layers.Dense(parse_attention_dim) self._head_keys_hidden = keras.layers.Dense(task_hidden_layer, activation="relu") self._head_keys_output = keras.layers.Dense(parse_attention_dim) self._deprel_hidden = keras.layers.Dense(task_hidden_layer, activation="relu") self._deprel_output = keras.layers.Dense(num_deprels) self._dropout = keras.layers.Dropout(dropout) def call(self, embeddings, embeddings_wo_root, embeddings_mask): head_queries = self._head_queries_output(self._dropout(self._head_queries_hidden(embeddings_wo_root))) head_keys = self._head_keys_output(self._dropout(self._head_keys_hidden(embeddings))) head_scores = keras.ops.matmul(head_queries, keras.ops.transpose(head_keys, axes=[0, 2, 1])) / keras.ops.sqrt(head_queries.shape[-1]) head_scores_mask = keras.ops.cast(keras.ops.expand_dims(embeddings_mask, axis=1), head_scores.dtype) head_scores = head_scores * head_scores_mask - 1e9 * (1 - head_scores_mask) predicted_heads = keras.ops.argmax(head_scores, axis=-1) predicted_head_embeddings = keras.ops.take_along_axis(embeddings, keras.ops.expand_dims(predicted_heads, axis=-1), axis=1) deprel_hidden = keras.ops.concatenate([embeddings_wo_root, predicted_head_embeddings], axis=-1) deprel_scores = self._deprel_output(self._dropout(self._deprel_hidden(deprel_hidden))) return head_scores, deprel_scores class SparseCategoricalCrossentropyWithLabelSmoothing(keras.losses.Loss): def __init__(self, from_logits: bool, label_smoothing: float, **kwargs): super().__init__(**kwargs) self._from_logits = from_logits self._label_smoothing = label_smoothing def call(self, y_true, y_pred): y_gold = keras.ops.one_hot(keras.ops.maximum(y_true, 0), y_pred.shape[-1]) if self._label_smoothing: y_pred_mask = keras.ops.cast(y_pred > -1e9, y_pred.dtype) y_gold = y_gold * (1 - self._label_smoothing) + y_pred_mask / keras.ops.sum(y_pred_mask, axis=-1, keepdims=True) * self._label_smoothing return keras.losses.categorical_crossentropy(y_gold, y_pred, from_logits=self._from_logits) def __init__(self, dataset: UDDataset, args: argparse.Namespace): self._dataset = dataset self._args = args # Create the transformer models self._tokenizers, self._transformers = [], [] for name in args.transformers: self._tokenizers.append(transformers.AutoTokenizer.from_pretrained(name, add_prefix_space=True)) transformer, transformer_opts = transformers.AutoModel, {} if "mt5" in name.lower(): transformer = transformers.MT5EncoderModel if name.endswith(("LaTa", "PhilTa")): transformer = transformers.T5EncoderModel if name.endswith(("LaBerta", "PhilBerta")): transformer_opts["add_pooling_layer"] = False if args.load: transformer = transformer.from_config(transformers.AutoConfig.from_pretrained(name), **transformer_opts) else: transformer = transformer.from_pretrained(name, **transformer_opts) # Create additional tokens additional_tokens = {} if args.treebank_ids: additional_tokens["additional_special_tokens"] = ["[TREEBANK_ID_{}]".format(i) for i in range(len(dataset.treebank_ids))] if self._tokenizers[-1].cls_token_id is None: # Generate CLS token if not present (for representing sentence root in parsing). additional_tokens["cls_token"] = "[CLS]" if additional_tokens: self._tokenizers[-1].add_special_tokens(additional_tokens) transformer.resize_token_embeddings(len(self._tokenizers[-1])) if args.treebank_ids: assert len(self._tokenizers[-1].additional_special_tokens) == len(dataset.treebank_ids) self._transformers.append(self.HFTransformerLayer(transformer, args.subword_combination, args.word_masking, self._tokenizers[-1].mask_token_id)) # Create the network inputs = [] for _ in args.transformers: inputs.extend([keras.layers.Input(shape=[None], dtype="int32"), keras.layers.Input(shape=[None, 2], dtype="int32")]) for _ in args.embed_tags: inputs.append(keras.layers.Input(shape=[None], dtype="int32")) # Run the transformer models embeddings = [] for tokens, word_indices, transformer in zip(inputs[::2], inputs[1::2], self._transformers): embeddings.append(transformer(tokens, word_indices)) embeddings = keras.layers.Concatenate(axis=-1)(embeddings) embeddings = keras.layers.Dropout(args.dropout)(embeddings) # Heads for the tagging tasks outputs = [] for tag in args.tags: hidden = keras.layers.Dense(args.task_hidden_layer, activation="relu")(embeddings[:, 1:]) hidden = keras.layers.Dropout(args.dropout)(hidden) outputs.append(keras.layers.Dense(len(dataset.factors[tag].words))(hidden)) # Head for parsing if args.parse: if args.embed_tags: all_embeddings = [embeddings] for factor, input_tags in zip(args.embed_tags, inputs[-len(args.embed_tags):]): embedding_layer = keras.layers.Embedding(len(dataset.factors[factor].words) + 1, 256) all_embeddings.append(keras.layers.Dropout(args.dropout)(embedding_layer(keras.ops.pad(input_tags + 1, [(0, 0), (1, 0)])))) embeddings = keras.ops.concatenate(all_embeddings, axis=-1) for i in range(args.rnn_layers): if args.rnn_type in ["LSTM", "GRU"]: hidden = keras.layers.Bidirectional(getattr(keras.layers, args.rnn_type)(args.rnn_dim, return_sequences=True))(embeddings, mask=inputs[1][..., 0] > -1) elif args.rnn_type in ["LSTMTorch", "GRUTorch"]: hidden = getattr(self, args.rnn_type)(args.rnn_dim)(embeddings, keras.ops.sum(inputs[1][..., 0] > -1, axis=-1)) hidden = keras.layers.Dropout(args.dropout)(hidden) embeddings = hidden + (embeddings if i else 0) outputs.extend(self.ParsingHead( len(dataset.factors[dataset.DEPREL].words), args.task_hidden_layer, args.parse_attention_dim, args.dropout, )(embeddings, embeddings[:, 1:], inputs[1][..., 0] > -1)) super().__init__(inputs=inputs, outputs=outputs) if args.load: self.load_weights(args.load[0]) def compile(self, epoch_batches: int, frozen: bool): args = self._args for transformer in self._transformers: transformer.trainable = not frozen if frozen: schedule = 1e-3 else: schedule = keras.optimizers.schedules.CosineDecay( 0. if args.learning_rate_warmup else args.learning_rate, args.epochs * epoch_batches - args.learning_rate_warmup, alpha=0.0 if args.learning_rate_decay != "none" else 1.0, warmup_target=args.learning_rate if args.learning_rate_warmup else None, warmup_steps=args.learning_rate_warmup, ) if args.optimizer == "adam": optimizer = keras.optimizers.Adam(schedule) elif args.optimizer == "adafactor": optimizer = keras.optimizers.Adafactor(schedule) else: raise ValueError("Unknown optimizer '{}'".format(args.optimizer)) super().compile( optimizer=optimizer, loss=self.SparseCategoricalCrossentropyWithLabelSmoothing(from_logits=True, label_smoothing=args.label_smoothing), ) @property def tokenizers(self) -> list[transformers.PreTrainedTokenizer]: return self._tokenizers def predict(self, dataloader: TorchUDDataLoader, save_as: str|None = None, args_override: argparse.Namespace|None = None) -> str: ud_dataset = dataloader.dataset.ud_dataset args = self._args if args_override is None else args_override conllu, sentence = io.StringIO(), 0 for batch_inputs, _, _ in dataloader: predictions = self.predict_on_batch(batch_inputs) for b in range(len(batch_inputs[0])): sentence_len = len(ud_dataset.factors[ud_dataset.FORMS].strings[sentence]) overrides = [None] * ud_dataset.FACTORS for tag, prediction in zip(args.tags, predictions): overrides[tag] = np.argmax(prediction[b, :sentence_len], axis=-1) if args.parse: heads, deprels = predictions[-2:] padded_heads = np.zeros([sentence_len + 1, sentence_len + 1], dtype=np.float64) padded_heads[1:] = heads[b, :sentence_len, :sentence_len + 1] padded_heads[1:] -= np.max(padded_heads[1:], axis=-1, keepdims=True) padded_heads[1:] -= np.log(np.sum(np.exp(padded_heads[1:]), axis=-1, keepdims=True)) if args.single_root: selected_root = 1 + np.argmax(padded_heads[1:, 0]) padded_heads[:, 0] = np.nan padded_heads[selected_root, 0] = 0 chosen_heads, _ = ufal.chu_liu_edmonds.chu_liu_edmonds(padded_heads) overrides[ud_dataset.HEAD] = chosen_heads[1:] overrides[ud_dataset.DEPREL] = np.argmax(deprels[b, :sentence_len], axis=-1) ud_dataset.write_sentence(conllu, sentence, overrides) sentence += 1 conllu = conllu.getvalue() if save_as is not None: os.makedirs(os.path.dirname(save_as), exist_ok=True) with open(save_as, "w", encoding="utf-8") as conllu_file: conllu_file.write(conllu) return conllu def evaluate(self, dataloader: TorchUDDataLoader, save_as: str|None = None, args_override: argparse.Namespace|None = None) -> tuple[str, dict[str, float]]: conllu = self.predict(dataloader, save_as=save_as, args_override=args_override) evaluation = latinpipe_evalatin24_eval.evaluate(dataloader.dataset.ud_dataset.conllu_for_eval, latinpipe_evalatin24_eval.load_conllu(io.StringIO(conllu))) if save_as is not None: os.makedirs(os.path.dirname(save_as), exist_ok=True) with open(save_as + ".eval", "w", encoding="utf-8") as eval_file: for metric, score in evaluation.items(): print("{}: {:.2f}%".format(metric, 100 * score.f1), file=eval_file) return conllu, evaluation class LatinPipeModelEnsemble: def __init__(self, latinpipe_model: LatinPipeModel, args: argparse.Namespace): self._latinpipe_model = latinpipe_model self._args = args def predict(self, dataloader: TorchUDDataLoader, save_as: str|None = None) -> str: def log_softmax(logits): logits -= np.max(logits, axis=-1, keepdims=True) logits -= np.log(np.sum(np.exp(logits), axis=-1, keepdims=True)) return logits ud_dataset = dataloader.dataset.ud_dataset # First compute all predictions overrides = [[0] * len(ud_dataset) if tag in self._args.tags + ([ud_dataset.HEAD, ud_dataset.DEPREL] if self._args.parse else []) else None for tag in range(ud_dataset.FACTORS)] for path in self._args.load: self._latinpipe_model.load_weights(path) sentence = 0 for batch_inputs, _, _ in dataloader: predictions = self._latinpipe_model.predict_on_batch(batch_inputs) for b in range(len(batch_inputs[0])): sentence_len = len(ud_dataset.factors[ud_dataset.FORMS].strings[sentence]) for tag, prediction in zip(self._args.tags, predictions): overrides[tag][sentence] += log_softmax(prediction[b, :sentence_len]) if self._args.parse: overrides[ud_dataset.HEAD][sentence] += log_softmax(predictions[-2][b, :sentence_len, :sentence_len + 1]) overrides[ud_dataset.DEPREL][sentence] += log_softmax(predictions[-1][b, :sentence_len]) sentence += 1 # Predict the most likely class and generate CoNLL-U output conllu = io.StringIO() for sentence in range(len(ud_dataset)): sentence_overrides = [None] * ud_dataset.FACTORS for tag in self._args.tags: sentence_overrides[tag] = np.argmax(overrides[tag][sentence], axis=-1) if self._args.parse: padded_heads = np.pad(overrides[ud_dataset.HEAD][sentence], [(1, 0), (0, 0)]).astype(np.float64) if self._args.single_root: selected_root = 1 + np.argmax(padded_heads[1:, 0]) padded_heads[:, 0] = np.nan padded_heads[selected_root, 0] = 0 chosen_heads, _ = ufal.chu_liu_edmonds.chu_liu_edmonds(padded_heads) sentence_overrides[ud_dataset.HEAD] = chosen_heads[1:] sentence_overrides[ud_dataset.DEPREL] = np.argmax(overrides[ud_dataset.DEPREL][sentence], axis=-1) ud_dataset.write_sentence(conllu, sentence, sentence_overrides) conllu = conllu.getvalue() if save_as is not None: os.makedirs(os.path.dirname(save_as), exist_ok=True) with open(save_as, "w", encoding="utf-8") as conllu_file: conllu_file.write(conllu) return conllu def evaluate(self, dataloader: TorchUDDataLoader, save_as: str|None = None) -> tuple[str, dict[str, float]]: return LatinPipeModel.evaluate(self, dataloader, save_as=save_as) def main(params: list[str] | None = None) -> None: args = parser.parse_args(params) # If supplied, load configuration from a trained model if args.load: with open(os.path.join(os.path.dirname(args.load[0]), "options.json"), mode="r") as options_file: args = argparse.Namespace(**{k: v for k, v in json.load(options_file).items() if k not in [ "dev", "exp", "load", "test", "threads", "verbose"]}) args = parser.parse_args(params, namespace=args) else: assert args.train, "Either --load or --train must be set." assert args.transformers, "At least one transformer must be specified." # Post-process arguments args.embed_tags = [UDDataset.FACTORS_MAP[tag] for tag in args.embed_tags.split(",") if tag] args.tags = [UDDataset.FACTORS_MAP[tag] for tag in args.tags.split(",") if tag] args.script = os.path.basename(__file__) # Create logdir args.logdir = os.path.join("logs", "{}{}-{}-{}-s{}".format( args.exp + "-" if args.exp else "", os.path.splitext(os.path.basename(globals().get("__file__", "notebook")))[0], os.environ.get("SLURM_JOB_ID", ""), datetime.datetime.now().strftime("%y%m%d_%H%M%S"), args.seed, # ",".join(("{}={}".format( # re.sub("(.)[^_]*_?", r"\1", k), # ",".join(re.sub(r"^.*/", "", str(x)) for x in ((v if len(v) <= 1 else [v[0], "..."]) if isinstance(v, list) else [v])), # ) for k, v in sorted(vars(args).items()) if k not in ["dev", "exp", "load", "test", "threads", "verbose"])) )) print(json.dumps(vars(args), sort_keys=True, ensure_ascii=False, indent=2)) os.makedirs(args.logdir, exist_ok=True) with open(os.path.join(args.logdir, "options.json"), mode="w") as options_file: json.dump(vars(args), options_file, sort_keys=True, ensure_ascii=False, indent=2) # Set the random seed and the number of threads keras.utils.set_random_seed(args.seed) torch.set_num_threads(args.threads) torch.set_num_interop_threads(args.threads) # Load the data if args.treebank_ids and max(len(args.train), len(args.dev), len(args.test)) > 1: print("WARNING: With treebank_ids, treebanks must always be in the same position in the train/dev/test.") if args.load: train = UDDataset.from_mappings(os.path.join(os.path.dirname(args.load[0]), "mappings.pkl")) else: train = UDDatasetMerged([UDDataset(path, args, treebank_id=i if args.treebank_ids else None) for i, path in enumerate(args.train)]) train.save_mappings(os.path.join(args.logdir, "mappings.pkl")) devs = [UDDataset(path, args, treebank_id=i if args.treebank_ids else None, train_dataset=train) for i, path in enumerate(args.dev)] tests = [UDDataset(path, args, treebank_id=i if args.treebank_ids else None, train_dataset=train) for i, path in enumerate(args.test)] # Create the model model = LatinPipeModel(train, args) # Create the dataloaders if not args.load: train_dataloader = TorchUDDataLoader(TorchUDDataset(train, model.tokenizers, args, training=True), args) dev_dataloaders = [TorchUDDataLoader(TorchUDDataset(dataset, model.tokenizers, args, training=False), args) for dataset in devs] test_dataloaders = [TorchUDDataLoader(TorchUDDataset(dataset, model.tokenizers, args, training=False), args) for dataset in tests] # Perform prediction if requested if args.load: if len(args.load) > 1: model = LatinPipeModelEnsemble(model, args) for dataloader in dev_dataloaders: model.evaluate(dataloader, save_as=os.path.splitext( os.path.join(args.exp, os.path.basename(dataloader.dataset.ud_dataset.path)) if args.exp else dataloader.dataset.ud_dataset.path )[0] + ".predicted.conllu") for dataloader in test_dataloaders: model.predict(dataloader, save_as=os.path.splitext( os.path.join(args.exp, os.path.basename(dataloader.dataset.ud_dataset.path)) if args.exp else dataloader.dataset.ud_dataset.path )[0] + ".predicted.conllu") return # Train the model class Evaluator(keras.callbacks.Callback): def __init__(self, wandb_log): super().__init__() self._wandb_log = wandb_log self._metrics = [["", "Lemmas", "UPOS", "XPOS", "UFeats"][tag] for tag in args.tags] + (["UAS", "LAS"] if args.parse else []) def on_epoch_end(self, epoch, logs=None): logs["learning_rate"] = keras.ops.convert_to_numpy(model.optimizer.learning_rate) for dataloader in dev_dataloaders + (test_dataloaders if epoch + 1 == args.epochs + args.epochs_frozen else []): _, metrics = model.evaluate(dataloader, save_as=os.path.splitext( os.path.join(args.logdir, os.path.basename(dataloader.dataset.ud_dataset.path)) )[0] + ".{:02d}.conllu".format(epoch + 1)) for metric, score in metrics.items(): if metric in self._metrics: logs["{}_{}".format(os.path.splitext(os.path.basename(dataloader.dataset.ud_dataset.path))[0], metric)] = 100 * score.f1 aggregations = {"la_ud213": [("la_ittb-ud", 390_787), ("la_llct-ud", 194_143), ("la_proiel-ud", 177_558), ("la_udante-ud", 30_450), ("la_perseus-ud", 16_486)]} for split in ["dev", "test"]: for metric in self._metrics: for aggregation, parts in aggregations.items(): values = [logs.get("{}-{}_{}".format(part, split, metric), None) for part, _ in parts] if all(value is not None for value in values): logs["{}-{}_{}".format(aggregation, split, metric)] = np.mean(values) logs["{}-sqrt-{}_{}".format(aggregation, split, metric)] = np.average(values, weights=[size**0.5 for _, size in parts]) if self._wandb_log: self._wandb_log(logs, step=epoch + 1, commit=True) wandb_log = None if args.wandb: import wandb wandb.init(project="ufal-evalatin2024", name=args.exp, config=vars(args)) wandb_log = wandb.log evaluator = Evaluator(wandb_log) if args.epochs_frozen: model.compile(len(train_dataloader), frozen=True) model.fit(train_dataloader, epochs=args.epochs_frozen, verbose=args.verbose, callbacks=[evaluator]) if args.epochs: model.compile(len(train_dataloader), frozen=False) model.fit(train_dataloader, initial_epoch=args.epochs_frozen, epochs=args.epochs_frozen + args.epochs, verbose=args.verbose, callbacks=[evaluator]) if args.save_checkpoint: model.save_weights(os.path.join(args.logdir, "model.weights.h5")) if __name__ == "__main__": main([] if "__file__" not in globals() else None)