|
|
""" |
|
|
A classifier that uses a constituency parser for the base embeddings |
|
|
""" |
|
|
|
|
|
import dataclasses |
|
|
import logging |
|
|
from types import SimpleNamespace |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from stanza.models.classifiers.base_classifier import BaseClassifier |
|
|
from stanza.models.classifiers.config import ConstituencyConfig |
|
|
from stanza.models.classifiers.data import SentimentDatum |
|
|
from stanza.models.classifiers.utils import ModelType, build_output_layers |
|
|
|
|
|
from stanza.models.common.utils import split_into_batches, sort_with_indices, unsort |
|
|
|
|
|
# shared top-level stanza logger for general messages
logger = logging.getLogger('stanza')

# separate logger for classifier training progress / configuration reporting
tlogger = logging.getLogger('stanza.classifiers.trainer')
|
|
|
|
|
class ConstituencyClassifier(BaseClassifier):
    """A classifier which uses a constituency parser's tree embeddings as features.

    Each input tree is embedded by the wrapped tree_embedding module, the
    per-node embeddings are max-pooled into a single vector per tree, and the
    pooled vector is passed through fully connected layers to produce class
    scores.
    """
    def __init__(self, tree_embedding, labels, args):
        """Build the classifier.

        tree_embedding: module providing embed_trees(), output_size,
            get_norms(), and get_params() (e.g. a constituency tree LSTM)
        labels: list of output class labels; len(labels) sets num_classes
        args: namespace carrying fc_shapes, dropout, and the constituency_*
            options which are copied into a ConstituencyConfig
        """
        super().__init__()
        self.labels = labels

        # copy only the fields relevant to this model type out of args, so
        # that the saved config does not depend on the full training arg set
        self.config = ConstituencyConfig(fc_shapes = args.fc_shapes,
                                         dropout = args.dropout,
                                         num_classes = len(labels),
                                         constituency_backprop = args.constituency_backprop,
                                         constituency_batch_norm = args.constituency_batch_norm,
                                         constituency_node_attn = args.constituency_node_attn,
                                         constituency_top_layer = args.constituency_top_layer,
                                         constituency_all_words = args.constituency_all_words,
                                         model_type = ModelType.CONSTITUENCY)

        self.tree_embedding = tree_embedding

        self.fc_layers = build_output_layers(self.tree_embedding.output_size, self.config.fc_shapes, self.config.num_classes)
        self.dropout = nn.Dropout(self.config.dropout)

    def is_unsaved_module(self, name):
        """No submodule is excluded from the state dict by name.

        (The tree_embedding's parameters are instead stripped explicitly in
        get_params and saved via the embedding's own get_params.)
        """
        return False

    def log_configuration(self):
        """Log the constituency-specific configuration to the trainer logger."""
        tlogger.info("Backprop into parser: %s", self.config.constituency_backprop)
        tlogger.info("Batch norm: %s", self.config.constituency_batch_norm)
        tlogger.info("Word positions used: %s", "all words" if self.config.constituency_all_words else "start and end words")
        tlogger.info("Attention over nodes: %s", self.config.constituency_node_attn)
        tlogger.info("Intermediate layers: %s", self.config.fc_shapes)

    def log_norms(self):
        """Log the L2 norm of every trainable parameter, for training diagnostics."""
        lines = ["NORMS FOR MODEL PARAMETERS"]
        # the tree embedding reports its own norms; skip its parameters below
        # to avoid listing them twice
        lines.extend(["tree_embedding." + x for x in self.tree_embedding.get_norms()])
        for name, param in self.named_parameters():
            if param.requires_grad and not name.startswith('tree_embedding.'):
                lines.append("%s %.6g" % (name, torch.norm(param).item()))
        logger.info("\n".join(lines))

    def forward(self, inputs):
        """Score a batch of inputs.

        inputs: list of SentimentDatum (the .constituency tree is used) or
            bare constituency trees
        Returns a tensor of shape (len(inputs), num_classes).
        """
        inputs = [x.constituency if isinstance(x, SentimentDatum) else x for x in inputs]

        # embed_trees presumably returns one (num_nodes, dim) tensor per tree
        # -- TODO confirm against the tree_embedding implementation
        embedding = self.tree_embedding.embed_trees(inputs)
        # max-pool over each tree's nodes to get a fixed-size vector per tree
        previous_layer = torch.stack([torch.max(x, dim=0)[0] for x in embedding], dim=0)
        previous_layer = self.dropout(previous_layer)
        for fc in self.fc_layers[:-1]:
            previous_layer = self.dropout(F.gelu(fc(previous_layer)))
        # final layer: raw class scores, no activation or dropout
        out = self.fc_layers[-1](previous_layer)
        return out

    def get_params(self, skip_modules=True):
        """Return a dict of everything needed to save and reload this model.

        The tree_embedding's parameters are removed from the main state dict
        and saved separately through the embedding's own get_params, so the
        underlying parser weights can be managed independently.
        """
        model_state = self.state_dict()

        skipped = [k for k in model_state.keys() if k.startswith("tree_embedding.")]
        for k in skipped:
            del model_state[k]

        tree_embedding = self.tree_embedding.get_params(skip_modules)

        config = dataclasses.asdict(self.config)
        # store the enum member by name so the saved config is serializable
        config['model_type'] = config['model_type'].name

        params = {
            'model': model_state,
            'tree_embedding': tree_embedding,
            'config': config,
            'labels': self.labels,
        }
        return params

    def extract_sentences(self, doc):
        """Pull the constituency tree out of each sentence of a parsed Document."""
        return [sentence.constituency for sentence in doc.sentences]
|
|
|