""" A module to use a Constituency Parser to make an embedding for a tree The embedding can be produced just from the words and the top of the tree, or it can be done with a form of attention over the nodes Can be done over an existing parse tree or unparsed text """ import torch import torch.nn as nn from stanza.models.constituency.trainer import Trainer class TreeEmbedding(nn.Module): def __init__(self, constituency_parser, args): super(TreeEmbedding, self).__init__() self.config = { "all_words": args["all_words"], "backprop": args["backprop"], #"batch_norm": args["batch_norm"], "node_attn": args["node_attn"], "top_layer": args["top_layer"], } self.constituency_parser = constituency_parser # word_lstm: hidden_size * num_tree_lstm_layers * 2 (start & end) # transition_stack: transition_hidden_size # constituent_stack: hidden_size self.hidden_size = self.constituency_parser.hidden_size + self.constituency_parser.transition_hidden_size if self.config["all_words"]: self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers else: self.hidden_size += self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers * 2 if self.config["node_attn"]: self.query = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size) self.key = nn.Linear(self.hidden_size, self.constituency_parser.hidden_size) self.value = nn.Linear(self.constituency_parser.hidden_size, self.constituency_parser.hidden_size) # TODO: cat transition and constituent hx as well? self.output_size = self.constituency_parser.hidden_size * self.constituency_parser.num_tree_lstm_layers else: self.output_size = self.hidden_size # TODO: maybe have batch_norm, maybe use Identity #if self.config["batch_norm"]: # self.input_norm = nn.BatchNorm1d(self.output_size) def embed_trees(self, inputs): if self.config["backprop"]: states = self.constituency_parser.analyze_trees(inputs) else: with torch.no_grad(): states = self.constituency_parser.analyze_trees(inputs) constituent_lists = [x.constituents for x in states] states = [x.state for x in states] word_begin_hx = torch.stack([state.word_queue[0].hx for state in states]) word_end_hx = torch.stack([state.word_queue[state.word_position].hx for state in states]) transition_hx = torch.stack([self.constituency_parser.transition_stack.output(state.transitions) for state in states]) # go down one layer to get the embedding off the top of the S, not the ROOT # (in terms of the typical treebank) # the idea being that the ROOT has no additional information # and may even have 0s for the embedding in certain circumstances, # such as after learning UNTIED_MAX long enough if self.config["top_layer"]: constituent_hx = torch.stack([self.constituency_parser.constituent_stack.output(state.constituents) for state in states]) else: constituent_hx = torch.cat([constituents[-2].tree_hx for constituents in constituent_lists], dim=0) if self.config["all_words"]: # need B matrices of N x hidden_size key = [torch.stack([torch.cat([word.hx, thx, chx]) for word in state.word_queue], dim=0) for state, thx, chx in zip(states, transition_hx, constituent_hx)] else: key = torch.cat((word_begin_hx, word_end_hx, transition_hx, constituent_hx), dim=1).unsqueeze(1) if not self.config["node_attn"]: return key key = [self.key(x) for x in key] node_hx = [torch.stack([con.tree_hx for con in constituents], dim=0) for constituents in constituent_lists] queries = [self.query(nhx).reshape(nhx.shape[0], -1) for nhx in node_hx] values = 
        # TODO: could pad to make faster here
        attn = [torch.matmul(q, k.transpose(0, 1)) for q, k in zip(queries, key)]
        attn = [torch.softmax(x, dim=0) for x in attn]
        previous_layer = [torch.matmul(weight.transpose(0, 1), value) for weight, value in zip(attn, values)]
        return previous_layer

    def forward(self, inputs):
        return self.embed_trees(inputs)

    def get_norms(self):
        lines = ["constituency_parser." + x for x in self.constituency_parser.get_norms()]
        for name, param in self.named_parameters():
            if param.requires_grad and not name.startswith('constituency_parser.'):
                lines.append("%s %.6g" % (name, torch.norm(param).item()))
        return lines

    def get_params(self, skip_modules=True):
        model_state = self.state_dict()
        # skip all of the constituency parameters here -
        # we will add them by calling the model's get_params()
        skipped = [k for k in model_state.keys() if k.startswith("constituency_parser.")]
        for k in skipped:
            del model_state[k]
        parser = self.constituency_parser.get_params(skip_modules)

        params = {
            'model': model_state,
            'constituency': parser,
            'config': self.config,
        }
        return params

    @staticmethod
    def from_parser_file(args, foundation_cache=None):
        constituency_parser = Trainer.load(args['model'], args, foundation_cache)
        return TreeEmbedding(constituency_parser.model, args)

    @staticmethod
    def model_from_params(params, args, foundation_cache=None):
        # TODO: integrate with peft
        constituency_parser = Trainer.model_from_params(params['constituency'], None, args, foundation_cache)
        model = TreeEmbedding(constituency_parser, params['config'])
        model.load_state_dict(params['model'], strict=False)
        return model
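
# A minimal usage sketch (illustrative, not part of the original module).
# The checkpoint path and the example tree are placeholder assumptions,
# and Trainer.load may expect additional keys in args beyond those shown.
# It assumes the inputs to embed_trees are parsed trees such as those
# returned by stanza.models.constituency.tree_reader.read_trees.
if __name__ == "__main__":
    from stanza.models.constituency import tree_reader

    args = {
        "model": "saved_models/constituency/en_wsj.pt",  # hypothetical checkpoint path
        "all_words": False,
        "backprop": False,
        "node_attn": False,
        "top_layer": True,
    }
    embedding = TreeEmbedding.from_parser_file(args)
    embedding.eval()

    trees = tree_reader.read_trees("(ROOT (S (NP (PRP I)) (VP (VBP like) (NP (NN tea)))))")
    vectors = embedding(trees)
    # with node_attn off, this is a single (batch, 1, output_size) tensor
    print(vectors.shape)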