from parser.modules import CHAR_LSTM, MLP, BertEmbedding, Biaffine, BiLSTM
from parser.modules.dropout import IndependentDropout, SharedDropout

import torch
import torch.nn as nn
from torch.nn.utils.rnn import (pack_padded_sequence, pad_packed_sequence,
                                pad_sequence)


class Model(nn.Module):
    """Biaffine dependency parser: a BiLSTM encoder followed by MLPs and
    biaffine attention that score arcs and relation labels
    (Dozat & Manning, 2017).
    """

    def __init__(self, args):
        super().__init__()

        self.args = args

        # embedding layer for in-vocabulary words
        self.word_embed = nn.Embedding(num_embeddings=args.n_words,
                                       embedding_dim=args.n_embed)
        # auxiliary feature embeddings: char-level LSTM, BERT, or plain tags
        if args.feat == 'char':
            self.feat_embed = CHAR_LSTM(n_chars=args.n_feats,
                                        n_embed=args.n_char_embed,
                                        n_out=args.n_embed)
        elif args.feat == 'bert':
            self.feat_embed = BertEmbedding(model=args.bert_model,
                                            n_layers=args.n_bert_layers,
                                            n_out=args.n_embed)
        else:
            self.feat_embed = nn.Embedding(num_embeddings=args.n_feats,
                                           embedding_dim=args.n_embed)
        self.embed_dropout = IndependentDropout(p=args.embed_dropout)

        # BiLSTM encoder over the concatenated word and feature embeddings
        self.lstm = BiLSTM(input_size=args.n_embed*2,
                           hidden_size=args.n_lstm_hidden,
                           num_layers=args.n_lstm_layers,
                           dropout=args.lstm_dropout)
        self.lstm_dropout = SharedDropout(p=args.lstm_dropout)

        # MLPs projecting encoder states into separate head/dependent
        # spaces for arc scoring and relation labeling
        self.mlp_arc_h = MLP(n_in=args.n_lstm_hidden*2,
                             n_hidden=args.n_mlp_arc,
                             dropout=args.mlp_dropout)
        self.mlp_arc_d = MLP(n_in=args.n_lstm_hidden*2,
                             n_hidden=args.n_mlp_arc,
                             dropout=args.mlp_dropout)
        self.mlp_rel_h = MLP(n_in=args.n_lstm_hidden*2,
                             n_hidden=args.n_mlp_rel,
                             dropout=args.mlp_dropout)
        self.mlp_rel_d = MLP(n_in=args.n_lstm_hidden*2,
                             n_hidden=args.n_mlp_rel,
                             dropout=args.mlp_dropout)

        # biaffine scorers for arcs and relation labels
        self.arc_attn = Biaffine(n_in=args.n_mlp_arc,
                                 bias_x=True,
                                 bias_y=False)
        self.rel_attn = Biaffine(n_in=args.n_mlp_rel,
                                 n_out=args.n_rels,
                                 bias_x=True,
                                 bias_y=True)
        self.pad_index = args.pad_index
        self.unk_index = args.unk_index

    def load_pretrained(self, embed=None):
        if embed is not None:
            self.pretrained = nn.Embedding.from_pretrained(embed)
            # zero the trainable word embeddings so they act as a residual
            # on top of the pretrained vectors added in forward()
            nn.init.zeros_(self.word_embed.weight)

        return self

    def forward(self, words, feats):
        batch_size, seq_len = words.shape
        # mask padding positions
        mask = words.ne(self.pad_index)
        lens = mask.sum(dim=1)
        # map out-of-vocabulary tokens to the UNK index
        ext_mask = words.ge(self.word_embed.num_embeddings)
        ext_words = words.masked_fill(ext_mask, self.unk_index)

        # look up word embeddings; pretrained vectors, if any, are added to
        # the trainable embeddings, which start at zero (see load_pretrained).
        # Concatenating them instead would break the BiLSTM input size.
        word_embed = self.word_embed(ext_words)
        if hasattr(self, 'pretrained'):
            word_embed += self.pretrained(words)
        if self.args.feat == 'char':
            feat_embed = self.feat_embed(feats[mask])
            feat_embed = pad_sequence(feat_embed.split(lens.tolist()), True)
        elif self.args.feat == 'bert':
            feat_embed = self.feat_embed(*feats)
        else:
            feat_embed = self.feat_embed(feats)
        word_embed, feat_embed = self.embed_dropout(word_embed, feat_embed)
        # concatenate the word and feature representations
        embed = torch.cat((word_embed, feat_embed), dim=-1)

        # pack_padded_sequence expects the lengths on the CPU
        lens = lens.to('cpu')
        x = pack_padded_sequence(embed, lens, True, False)
        x, _ = self.lstm(x)
        x, _ = pad_packed_sequence(x, True, total_length=seq_len)
        x = self.lstm_dropout(x)

        # apply MLPs to the BiLSTM output states
        arc_h = self.mlp_arc_h(x)
        arc_d = self.mlp_arc_d(x)
        rel_h = self.mlp_rel_h(x)
        rel_d = self.mlp_rel_d(x)

        # [batch_size, seq_len, seq_len]
        s_arc = self.arc_attn(arc_d, arc_h)
        # [batch_size, seq_len, seq_len, n_rels]
        s_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1)
        # set the scores of padding positions to -inf so they are never chosen
        s_arc.masked_fill_(~mask.unsqueeze(1), float('-inf'))

        return s_arc, s_rel

    @classmethod
    def load(cls, path):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        state = torch.load(path, map_location=device)
        model = cls(state['args'])
        model.load_pretrained(state['pretrained'])
        # strict=False because the pretrained weights are stored separately
        model.load_state_dict(state['state_dict'], False)
        model.to(device)

        return model

    def save(self, path):
        state_dict, pretrained = self.state_dict(), None
        if hasattr(self, 'pretrained'):
            # keep the (frozen) pretrained embeddings out of the state dict
            pretrained = state_dict.pop('pretrained.weight')
        state = {
            'args': self.args,
            'state_dict': state_dict,
            'pretrained': pretrained
        }
        torch.save(state, path)
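

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch for quick shape checking, not part of the
# training pipeline. The hyperparameter values below are illustrative
# assumptions; only the attribute names are taken from Model.__init__.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(n_words=100, n_feats=50, n_rels=10, feat='tag',
                           n_embed=100, embed_dropout=0.33,
                           n_lstm_hidden=400, n_lstm_layers=3,
                           lstm_dropout=0.33, n_mlp_arc=500, n_mlp_rel=100,
                           mlp_dropout=0.33, pad_index=0, unk_index=1)
    model = Model(args)
    model.eval()

    # two sentences of lengths 5 and 3, right-padded with pad_index
    words = torch.tensor([[2, 3, 4, 5, 6],
                          [7, 8, 9, 0, 0]])
    feats = torch.tensor([[1, 2, 3, 4, 5],
                          [6, 7, 8, 0, 0]])
    s_arc, s_rel = model(words, feats)
    assert s_arc.shape == (2, 5, 5)
    assert s_rel.shape == (2, 5, 5, args.n_rels)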