# -*- coding: utf-8 -*-

from parser.modules import CHAR_LSTM, MLP, BertEmbedding, Biaffine, BiLSTM
from parser.modules.dropout import IndependentDropout, SharedDropout

import torch
import torch.nn as nn
from torch.nn.utils.rnn import (pack_padded_sequence, pad_packed_sequence,
                                pad_sequence)


class Model(nn.Module):

    def __init__(self, args):
        super(Model, self).__init__()

        self.args = args
        # the embedding layer
        self.word_embed = nn.Embedding(num_embeddings=args.n_words,
                                       embedding_dim=args.n_embed)
        if args.feat == 'char':
            self.feat_embed = CHAR_LSTM(n_chars=args.n_feats,
                                        n_embed=args.n_char_embed,
                                        n_out=args.n_embed)
        elif args.feat == 'bert':
            self.feat_embed = BertEmbedding(model=args.bert_model,
                                            n_layers=args.n_bert_layers,
                                            n_out=args.n_embed)
        else:
            self.feat_embed = nn.Embedding(num_embeddings=args.n_feats,
                                           embedding_dim=args.n_embed)
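        # IndependentDropout (from parser.modules.dropout) drops the word and
        # feat embeddings of each token independently and rescales whichever
        # one survives, rather than zeroing individual units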
        self.embed_dropout = IndependentDropout(p=args.embed_dropout)

        # the word-lstm layer
        self.lstm = BiLSTM(input_size=args.n_embed*2,
                           hidden_size=args.n_lstm_hidden,
                           num_layers=args.n_lstm_layers,
                           dropout=args.lstm_dropout)
        self.lstm_dropout = SharedDropout(p=args.lstm_dropout)

        # the MLP layers
        self.mlp_arc_h = MLP(n_in=args.n_lstm_hidden*2,
                             n_hidden=args.n_mlp_arc,
                             dropout=args.mlp_dropout)
        self.mlp_arc_d = MLP(n_in=args.n_lstm_hidden*2,
                             n_hidden=args.n_mlp_arc,
                             dropout=args.mlp_dropout)
        self.mlp_rel_h = MLP(n_in=args.n_lstm_hidden*2,
                             n_hidden=args.n_mlp_rel,
                             dropout=args.mlp_dropout)
        self.mlp_rel_d = MLP(n_in=args.n_lstm_hidden*2,
                             n_hidden=args.n_mlp_rel,
                             dropout=args.mlp_dropout)

        # the Biaffine layers
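        # arc_attn produces a single score per (dependent, head) pair, while
        # rel_attn scores every pair over the n_rels relation labels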
        self.arc_attn = Biaffine(n_in=args.n_mlp_arc,
                                 bias_x=True,
                                 bias_y=False)
        self.rel_attn = Biaffine(n_in=args.n_mlp_rel,
                                 n_out=args.n_rels,
                                 bias_x=True,
                                 bias_y=True)
        self.pad_index = args.pad_index
        self.unk_index = args.unk_index

    def load_pretrained(self, embed=None):
        if embed is not None:
            self.pretrained = nn.Embedding.from_pretrained(embed)
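            # the trainable word embeddings are zero-initialized so that, at
            # the start of training, a token's representation is exactly its
            # pretrained vector; word_embed then learns a correction on top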
            nn.init.zeros_(self.word_embed.weight)

        return self

    def forward(self, words, feats):
        batch_size, seq_len = words.shape
        # get the mask and lengths of given batch
        mask = words.ne(self.pad_index)
        lens = mask.sum(dim=1)
        # set the indices larger than num_embeddings to unk_index
        ext_mask = words.ge(self.word_embed.num_embeddings)
        ext_words = words.masked_fill(ext_mask, self.unk_index)

        # get outputs from embedding layers
        word_embed = self.word_embed(ext_words)
        if hasattr(self, 'pretrained'):
            word_embed += self.pretrained(words)
        if self.args.feat == 'char':
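            # feats[mask] flattens every real (non-pad) token's character ids
            # into one batch; CHAR_LSTM encodes each word from its characters,
            # then the word vectors are regrouped per sentence and re-padded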
            feat_embed = self.feat_embed(feats[mask])
            feat_embed = pad_sequence(feat_embed.split(lens.tolist()), True)
        elif self.args.feat == 'bert':
            feat_embed = self.feat_embed(*feats)
        else:
            feat_embed = self.feat_embed(feats)
        word_embed, feat_embed = self.embed_dropout(word_embed, feat_embed)
        # concatenate the word and feat representations
        embed = torch.cat((word_embed, feat_embed), dim=-1)
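
        # pack by the true sentence lengths (kept on CPU, as required by
        # pack_padded_sequence) so the BiLSTM never runs over padding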
        lens = lens.to('cpu')
        x = pack_padded_sequence(embed, lens, True, False)
        x, _ = self.lstm(x)
        x, _ = pad_packed_sequence(x, True, total_length=seq_len)
        x = self.lstm_dropout(x)

        # apply MLPs to the BiLSTM output states
        arc_h = self.mlp_arc_h(x)
        arc_d = self.mlp_arc_d(x)
        rel_h = self.mlp_rel_h(x)
        rel_d = self.mlp_rel_d(x)

        # get arc and rel scores from the bilinear attention
        # [batch_size, seq_len, seq_len]
        s_arc = self.arc_attn(arc_d, arc_h)
        # [batch_size, seq_len, seq_len, n_rels]
        s_rel = self.rel_attn(rel_d, rel_h).permute(0, 2, 3, 1)
        # set the scores that exceed the length of each sentence to -inf
        s_arc.masked_fill_(~mask.unsqueeze(1), float('-inf'))

        return s_arc, s_rel

    @classmethod
    def load(cls, path):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        state = torch.load(path, map_location=device)
        model = cls(state['args'])
        model.load_pretrained(state['pretrained'])
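        # strict=False because save() pops 'pretrained.weight' out of the
        # state dict; the pretrained embeddings are restored separately above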
        model.load_state_dict(state['state_dict'], False)
        model.to(device)

        return model

    def save(self, path):
        state_dict, pretrained = self.state_dict(), None
        if hasattr(self, 'pretrained'):
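            # store the (possibly large, frozen) pretrained embedding matrix
            # separately instead of duplicating it inside the state dict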
            pretrained = state_dict.pop('pretrained.weight')
        state = {
            'args': self.args,
            'state_dict': state_dict,
            'pretrained': pretrained
        }
        torch.save(state, path)
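

# A minimal usage sketch (hypothetical names and shapes; `args`, `embed`,
# `words` and `feats` come from the corpus/field preprocessing elsewhere in
# this repo):
#
#     model = Model(args).load_pretrained(embed)
#     s_arc, s_rel = model(words, feats)
#     # s_arc: [batch_size, seq_len, seq_len]
#     # s_rel: [batch_size, seq_len, seq_len, n_rels]
#     arc_preds = s_arc.argmax(-1)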