# Source: encoder_freezed_70epochs_10e-5 / modeling_parser.py
# Uploaded by E-katrin, commit 808715d ("Upload ConlluTokenClassificationPipeline", verified).
from torch import nn
from torch import LongTensor
from transformers import PreTrainedModel
from .configuration import CobaldParserConfig
from .encoder import WordTransformerEncoder
from .mlp_classifier import MlpClassifier
from .dependency_classifier import DependencyClassifier
from .utils import (
build_padding_mask,
build_null_mask,
prepend_cls,
remove_nulls,
add_nulls
)
class CobaldParser(PreTrainedModel):
    """Morpho-Syntax-Semantic Parser.

    A shared ``WordTransformerEncoder`` feeds a set of task heads stored in
    ``self.classifiers`` (an ``nn.ModuleDict``):

    * ``"null"`` (always present): an ``MlpClassifier`` with
      ``config.consecutive_null_limit + 1`` classes whose predictions are used
      to restore #NULL tokens (presumably the number of #NULLs to insert at each
      position — TODO confirm against ``add_nulls``).
    * ``"lemma_rule"``, ``"joint_feats"``, ``"misc"``, ``"deepslot"``,
      ``"semclass"``: ``MlpClassifier`` heads, each created only when the
      matching key is present in ``config.vocabulary``.
    * ``"syntax"``: a ``DependencyClassifier`` predicting both UD and enhanced
      UD relations; created when ``"ud_deprel"`` or ``"eud_deprel"`` is in the
      vocabulary, and requires both.
    """

    config_class = CobaldParserConfig

    def __init__(self, config: CobaldParserConfig):
        """Build the encoder and every head enabled by ``config.vocabulary``.

        Raises:
            ValueError: if only one of ``"ud_deprel"`` / ``"eud_deprel"`` is
                present in ``config.vocabulary`` (the syntax head needs both).
        """
        super().__init__(config)

        self.encoder = WordTransformerEncoder(
            model_name=config.encoder_model_name
        )
        embedding_size = self.encoder.get_embedding_size()

        # Heads are registered in a fixed order (null, lemma_rule, joint_feats,
        # syntax, misc, deepslot, semclass); keep it stable, since it determines
        # the order of self.parameters().
        self.classifiers = nn.ModuleDict()
        self.classifiers["null"] = self._make_mlp_classifier(
            config,
            input_size=embedding_size,
            hidden_size=config.null_classifier_hidden_size,
            n_classes=config.consecutive_null_limit + 1,
        )
        if "lemma_rule" in config.vocabulary:
            self.classifiers["lemma_rule"] = self._make_mlp_classifier(
                config,
                input_size=embedding_size,
                hidden_size=config.lemma_classifier_hidden_size,
                n_classes=len(config.vocabulary["lemma_rule"]),
            )
        if "joint_feats" in config.vocabulary:
            self.classifiers["joint_feats"] = self._make_mlp_classifier(
                config,
                input_size=embedding_size,
                hidden_size=config.morphology_classifier_hidden_size,
                n_classes=len(config.vocabulary["joint_feats"]),
            )

        has_ud = "ud_deprel" in config.vocabulary
        has_eud = "eud_deprel" in config.vocabulary
        if has_ud or has_eud:
            # forward() always consumes both UD and EUD predictions/losses, so
            # both vocabularies are needed; fail early with a clear message
            # instead of a bare KeyError.
            if not (has_ud and has_eud):
                missing = "eud_deprel" if has_ud else "ud_deprel"
                raise ValueError(
                    "The syntax classifier requires both 'ud_deprel' and "
                    f"'eud_deprel' vocabularies, but '{missing}' is missing "
                    "from config.vocabulary."
                )
            self.classifiers["syntax"] = DependencyClassifier(
                input_size=embedding_size,
                hidden_size=config.dependency_classifier_hidden_size,
                n_rels_ud=len(config.vocabulary["ud_deprel"]),
                n_rels_eud=len(config.vocabulary["eud_deprel"]),
                activation=config.activation,
                dropout=config.dropout,
            )

        if "misc" in config.vocabulary:
            self.classifiers["misc"] = self._make_mlp_classifier(
                config,
                input_size=embedding_size,
                hidden_size=config.misc_classifier_hidden_size,
                n_classes=len(config.vocabulary["misc"]),
            )
        if "deepslot" in config.vocabulary:
            self.classifiers["deepslot"] = self._make_mlp_classifier(
                config,
                input_size=embedding_size,
                hidden_size=config.deepslot_classifier_hidden_size,
                n_classes=len(config.vocabulary["deepslot"]),
            )
        if "semclass" in config.vocabulary:
            self.classifiers["semclass"] = self._make_mlp_classifier(
                config,
                input_size=embedding_size,
                hidden_size=config.semclass_classifier_hidden_size,
                n_classes=len(config.vocabulary["semclass"]),
            )

    @staticmethod
    def _make_mlp_classifier(config, input_size, hidden_size, n_classes):
        """Build an ``MlpClassifier`` using the activation/dropout shared via *config*."""
        return MlpClassifier(
            input_size=input_size,
            hidden_size=hidden_size,
            n_classes=n_classes,
            activation=config.activation,
            dropout=config.dropout,
        )

    def _apply_mlp_head(self, name, embeddings, labels, output, output_key):
        """Run MLP head *name* if it exists, store its preds and add its loss.

        Predictions go to ``output[output_key]``; the head's loss is added to
        ``output["loss"]`` out of place, so no tensor returned by another head
        is mutated.
        """
        if name not in self.classifiers:
            return
        head_output = self.classifiers[name](embeddings, labels)
        output[output_key] = head_output["preds"]
        output["loss"] = output["loss"] + head_output["loss"]

    def forward(
        self,
        words: list[list[str]],
        counting_masks: LongTensor = None,
        lemma_rules: LongTensor = None,
        joint_feats: LongTensor = None,
        deps_ud: LongTensor = None,
        deps_eud: LongTensor = None,
        miscs: LongTensor = None,
        deepslots: LongTensor = None,
        semclasses: LongTensor = None,
        sent_ids: list[str] = None,
        texts: list[str] = None,
        inference_mode: bool = False
    ) -> dict:
        """Encode a batch of sentences and run every configured head.

        Args:
            words: batch of tokenized sentences (may contain gold #NULL tokens).
            counting_masks, lemma_rules, joint_feats, deps_ud, deps_eud, miscs,
                deepslots, semclasses: gold labels passed to the matching head.
                Presumably ``None`` at inference time; how a head handles
                ``None`` is defined in the classifier modules (not shown here).
            sent_ids, texts: accepted for interface compatibility; not used here.
            inference_mode: if True, the sentences fed to the heads are rebuilt
                from the *predicted* nulls via ``add_nulls``; otherwise the input
                ``words`` are used unchanged ("teacher forcing").

        Returns:
            A dict with ``"counting_mask"``, ``"loss"``, ``"words"`` and, for each
            present head, its predictions under ``"lemma_rules"``,
            ``"joint_feats"``, ``"deps_ud"``/``"deps_eud"``, ``"miscs"``,
            ``"deepslots"``, ``"semclasses"``. ``"loss"`` is the sum of all head
            losses.
        """
        output = {}

        # Extra [CLS] token accounts for the case when #NULL is the first token in a sentence.
        words_with_cls = prepend_cls(words)
        words_without_nulls = remove_nulls(words_with_cls)
        # Embeddings of words without nulls.
        embeddings_without_nulls = self.encoder(words_without_nulls)

        # Predict nulls.
        null_output = self.classifiers["null"](embeddings_without_nulls, counting_masks)
        output["counting_mask"] = null_output["preds"]
        output["loss"] = null_output["loss"]

        # "Teacher forcing": during training, pass the original words (with gold nulls)
        # to the classification heads, so that they are trained upon correct sentences.
        if inference_mode:
            # Restore predicted nulls in the original sentences.
            output["words"] = add_nulls(words, null_output["preds"])
        else:
            output["words"] = words

        # Encode words with nulls.
        embeddings = self.encoder(output["words"])

        # Predict lemmas and morphological features.
        self._apply_mlp_head("lemma_rule", embeddings, lemma_rules, output, "lemma_rules")
        self._apply_mlp_head("joint_feats", embeddings, joint_feats, output, "joint_feats")

        # Predict syntax.
        if "syntax" in self.classifiers:
            padding_mask = build_padding_mask(output["words"], self.device)
            null_mask = build_null_mask(output["words"], self.device)
            deps_output = self.classifiers["syntax"](
                embeddings,
                deps_ud,
                deps_eud,
                null_mask,
                padding_mask
            )
            output["deps_ud"] = deps_output["preds_ud"]
            output["deps_eud"] = deps_output["preds_eud"]
            output["loss"] = output["loss"] + (
                deps_output["loss_ud"] + deps_output["loss_eud"]
            )

        # Predict miscellaneous features.
        self._apply_mlp_head("misc", embeddings, miscs, output, "miscs")

        # Predict semantics.
        self._apply_mlp_head("deepslot", embeddings, deepslots, output, "deepslots")
        self._apply_mlp_head("semclass", embeddings, semclasses, output, "semclasses")

        return output