| | from torch import nn |
| | from torch import LongTensor |
| | from transformers import PreTrainedModel |
| |
|
| | from .configuration import CobaldParserConfig |
| | from .encoder import WordTransformerEncoder |
| | from .mlp_classifier import MlpClassifier |
| | from .dependency_classifier import DependencyClassifier |
| | from .utils import ( |
| | build_padding_mask, |
| | build_null_mask, |
| | prepend_cls, |
| | remove_nulls, |
| | add_nulls |
| | ) |
| |
|
| |
|
class CobaldParser(PreTrainedModel):
    """Morpho-Syntax-Semantic Parser.

    A shared ``WordTransformerEncoder`` feeds the classification heads stored
    in ``self.classifiers``:

    * ``"null"`` - always built; predicts the ``counting_mask``
      (``config.consecutive_null_limit + 1`` classes);
    * ``"lemma_rule"``, ``"joint_feats"``, ``"misc"``, ``"deepslot"``,
      ``"semclass"`` - MLP heads, each built only when the key of the same
      name is present in ``config.vocabulary``;
    * ``"syntax"`` - a ``DependencyClassifier`` built when
      ``config.vocabulary`` contains both ``"ud_deprel"`` and ``"eud_deprel"``.
    """

    config_class = CobaldParserConfig

    def __init__(self, config: CobaldParserConfig):
        """Build the encoder and the heads enabled by ``config.vocabulary``.

        Raises:
            ValueError: if ``config.vocabulary`` contains exactly one of
                ``"ud_deprel"`` / ``"eud_deprel"`` (the syntax head needs both).
        """
        super().__init__(config)

        self.encoder = WordTransformerEncoder(
            model_name=config.encoder_model_name
        )
        embedding_size = self.encoder.get_embedding_size()
        vocabulary = config.vocabulary

        def mlp_head(hidden_size: int, n_classes: int) -> MlpClassifier:
            # Every MLP head reads the same encoder output and shares the
            # activation/dropout settings; only width and class count differ.
            return MlpClassifier(
                input_size=embedding_size,
                hidden_size=hidden_size,
                n_classes=n_classes,
                activation=config.activation,
                dropout=config.dropout
            )

        # Keep this construction order: it fixes the ModuleDict key order
        # (hence parameter order) and, presumably, the order in which random
        # initialisation is drawn.
        self.classifiers = nn.ModuleDict()
        self.classifiers["null"] = mlp_head(
            config.null_classifier_hidden_size,
            config.consecutive_null_limit + 1
        )
        if "lemma_rule" in vocabulary:
            self.classifiers["lemma_rule"] = mlp_head(
                config.lemma_classifier_hidden_size,
                len(vocabulary["lemma_rule"])
            )
        if "joint_feats" in vocabulary:
            self.classifiers["joint_feats"] = mlp_head(
                config.morphology_classifier_hidden_size,
                len(vocabulary["joint_feats"])
            )

        has_ud = "ud_deprel" in vocabulary
        has_eud = "eud_deprel" in vocabulary
        if has_ud != has_eud:
            # DependencyClassifier needs both relation inventories; fail with
            # a clear message instead of a bare KeyError below.
            missing = "eud_deprel" if has_ud else "ud_deprel"
            raise ValueError(
                "config.vocabulary must contain both 'ud_deprel' and "
                f"'eud_deprel' to build the syntax head; '{missing}' is missing."
            )
        if has_ud and has_eud:
            self.classifiers["syntax"] = DependencyClassifier(
                input_size=embedding_size,
                hidden_size=config.dependency_classifier_hidden_size,
                n_rels_ud=len(vocabulary["ud_deprel"]),
                n_rels_eud=len(vocabulary["eud_deprel"]),
                activation=config.activation,
                dropout=config.dropout
            )

        if "misc" in vocabulary:
            self.classifiers["misc"] = mlp_head(
                config.misc_classifier_hidden_size,
                len(vocabulary["misc"])
            )
        if "deepslot" in vocabulary:
            self.classifiers["deepslot"] = mlp_head(
                config.deepslot_classifier_hidden_size,
                len(vocabulary["deepslot"])
            )
        if "semclass" in vocabulary:
            self.classifiers["semclass"] = mlp_head(
                config.semclass_classifier_hidden_size,
                len(vocabulary["semclass"])
            )

    def _apply_mlp_head(
        self,
        output: dict,
        head_name: str,
        output_key: str,
        embeddings,
        labels: LongTensor = None
    ) -> None:
        """Run MLP head ``head_name`` (if built) and merge its result into ``output``.

        Stores the head's ``"preds"`` under ``output[output_key]`` and adds its
        ``"loss"`` to ``output["loss"]``. Does nothing when the head is absent.
        """
        if head_name not in self.classifiers:
            return
        head_output = self.classifiers[head_name](embeddings, labels)
        output[output_key] = head_output["preds"]
        # Out-of-place add: `+=` on a tensor would modify, in place, the loss
        # tensor originally returned by the null head.
        output["loss"] = output["loss"] + head_output["loss"]

    def forward(
        self,
        words: list[list[str]],
        counting_masks: LongTensor = None,
        lemma_rules: LongTensor = None,
        joint_feats: LongTensor = None,
        deps_ud: LongTensor = None,
        deps_eud: LongTensor = None,
        miscs: LongTensor = None,
        deepslots: LongTensor = None,
        semclasses: LongTensor = None,
        sent_ids: list[str] = None,
        texts: list[str] = None,
        inference_mode: bool = False
    ) -> dict:
        """Encode a batch of sentences and run every configured head.

        Args:
            words: batch of tokenized sentences.
            counting_masks, lemma_rules, joint_feats, deps_ud, deps_eud,
            miscs, deepslots, semclasses: optional targets, passed unchanged
                to the corresponding head together with the embeddings.
            sent_ids, texts: accepted but not used by this method
                (presumably supplied by the data collator - verify).
            inference_mode: if True, the output of the "null" head is used to
                insert nulls into ``words`` (via ``add_nulls``) before the
                second encoder pass; otherwise ``words`` is encoded as given.

        Returns:
            dict with "counting_mask", "loss" (sum of all head losses) and
            "words" (the sentences the downstream heads ran on), plus the
            predictions of each built head: "lemma_rules", "joint_feats",
            "deps_ud" and "deps_eud", "miscs", "deepslots", "semclasses".
        """
        output = {}

        # The null head sees the sentences with a CLS token prepended and
        # nulls removed. NOTE(review): the CLS slot presumably allows nulls
        # to be predicted before the first word - confirm in prepend_cls.
        words_with_cls = prepend_cls(words)
        words_without_nulls = remove_nulls(words_with_cls)
        embeddings_without_nulls = self.encoder(words_without_nulls)

        null_output = self.classifiers["null"](embeddings_without_nulls, counting_masks)
        output["counting_mask"] = null_output["preds"]
        output["loss"] = null_output["loss"]

        # Outside inference the given words are used as-is (presumably with
        # gold nulls during training); at inference predicted nulls are added.
        if inference_mode:
            output["words"] = add_nulls(words, null_output["preds"])
        else:
            output["words"] = words

        # Second encoder pass, over the sentences the remaining heads use.
        embeddings = self.encoder(output["words"])

        # Heads run in a fixed order so the loss is summed in a stable order.
        self._apply_mlp_head(output, "lemma_rule", "lemma_rules", embeddings, lemma_rules)
        self._apply_mlp_head(output, "joint_feats", "joint_feats", embeddings, joint_feats)

        if "syntax" in self.classifiers:
            padding_mask = build_padding_mask(output["words"], self.device)
            null_mask = build_null_mask(output["words"], self.device)
            deps_output = self.classifiers["syntax"](
                embeddings,
                deps_ud,
                deps_eud,
                null_mask,
                padding_mask
            )
            output["deps_ud"] = deps_output["preds_ud"]
            output["deps_eud"] = deps_output["preds_eud"]
            output["loss"] = output["loss"] + (
                deps_output["loss_ud"] + deps_output["loss_eud"]
            )

        self._apply_mlp_head(output, "misc", "miscs", embeddings, miscs)
        self._apply_mlp_head(output, "deepslot", "deepslots", embeddings, deepslots)
        self._apply_mlp_head(output, "semclass", "semclasses", embeddings, semclasses)

        return output