File size: 18,426 Bytes
e8f4897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import from_numpy
from ..nn import Embedding
from ..nn import BiAAttention, BiLinear
from utils.tasks import parse
from ..nn import utils

class BiAffine_Parser(nn.Module):
    def __init__(self, word_dim, num_words, char_dim, num_chars, use_pos, use_char, pos_dim, num_pos, num_filters,
                 kernel_size, rnn_mode, hidden_size, num_layers, num_arcs,
                 arc_space, arc_tag_space, embedd_word=None, embedd_char=None, embedd_pos=None,
                 p_in=0.33, p_out=0.33, p_rnn=(0.33, 0.33), biaffine=True, arc_decode='mst',
                 initializer=None):
        super(BiAffine_Parser, self).__init__()
        self.rnn_encoder = BiRecurrentConv_Encoder(word_dim, num_words, char_dim, num_chars, use_pos, use_char,
                                                   pos_dim, num_pos, num_filters,
                                                   kernel_size, rnn_mode, hidden_size,
                                                   num_layers, embedd_word=embedd_word,
                                                   embedd_char=embedd_char, embedd_pos=embedd_pos,
                                                   p_in=p_in, p_out=p_out, p_rnn=p_rnn, initializer=initializer)
        self.parser = BiAffine_Parser_Decoder(hidden_size, num_arcs, arc_space, arc_tag_space, biaffine, p_out, arc_decode)

    def forward(self, input_word, input_char, input_pos, mask=None, length=None, hx=None):
        encoder_output, hn, mask, length = self.rnn_encoder(input_word, input_char, input_pos, mask, length, hx)
        out_arc, out_arc_tag = self.parser(encoder_output, mask)
        return out_arc, out_arc_tag, mask, length

    def loss(self, out_arc, out_arc_tag, heads, arc_tags, mask=None, length=None):
        # out_arc shape [batch_size, length, length]
        # out_arc_tag shape [batch_size, length, arc_tag_space]
        loss_arc, loss_arc_tag = self.parser.loss(out_arc, out_arc_tag, heads, arc_tags, mask, length)
        return loss_arc, loss_arc_tag

    def loss_per_sample(self, out_arc, out_arc_tag, heads, arc_tags, mask=None, length=None):
        # out_arc shape [batch_size, length, length]
        # out_arc_tag shape [batch_size, length, arc_tag_space]
        loss_arc, loss_arc_tag = self.parser.loss_per_sample(out_arc, out_arc_tag, heads, arc_tags, mask, length)
        return loss_arc, loss_arc_tag

    def decode(self, out_arc, out_arc_tag, mask=None, length=None, leading_symbolic=0):
        heads_pred, arc_tags_pred, scores = self.parser.decode(out_arc, out_arc_tag, mask, length, leading_symbolic)
        return heads_pred, arc_tags_pred, scores

    def pre_loss(self, out_arc, out_arc_tag, heads, arc_tags, mask=None, length=None, use_log=True, temperature=1.0):
        out_arc, out_arc_tag = self.parser.pre_loss(out_arc, out_arc_tag, heads, arc_tags, mask, length, use_log, temperature)
        return out_arc, out_arc_tag

class BiAffine_Parser_Decoder(nn.Module):
    def __init__(self, hidden_size, num_arcs, arc_space, arc_tag_space, biaffine, p_out, arc_decode):
        super(BiAffine_Parser_Decoder, self).__init__()
        self.num_arcs = num_arcs
        self.arc_space = arc_space
        self.arc_tag_space = arc_tag_space
        self.out_dim = hidden_size * 2
        self.biaffine = biaffine
        self.p_out = p_out
        self.arc_decode = arc_decode
        self.dropout_out = nn.Dropout(self.p_out)
        self.arc_h = nn.Linear(self.out_dim, self.arc_space)
        self.arc_c = nn.Linear(self.out_dim, self.arc_space)
        self.attention = BiAAttention(self.arc_space, self.arc_space, 1, biaffine=biaffine)
        self.arc_tag_h = nn.Linear(self.out_dim, arc_tag_space)
        self.arc_tag_c = nn.Linear(self.out_dim, arc_tag_space)
        self.bilinear = BiLinear(arc_tag_space, arc_tag_space, num_arcs)

    def forward(self, input, mask):
        # apply dropout for output
        # [batch_size, length, hidden_size] --> [batch_size, hidden_size, length] --> [batch_size, length, hidden_size]
        input = self.dropout_out(input.transpose(1, 2)).transpose(1, 2)

        # output size [batch_size, length, arc_space]
        arc_h = F.elu(self.arc_h(input))
        arc_c = F.elu(self.arc_c(input))

        # output size [batch_size, length, arc_tag_space]
        arc_tag_h = F.elu(self.arc_tag_h(input))
        arc_tag_c = F.elu(self.arc_tag_c(input))

        # apply dropout
        # [batch_size, length, dim] --> [batch_size, 2 * length, dim]
        arc = torch.cat([arc_h, arc_c], dim=1)
        arc_tag = torch.cat([arc_tag_h, arc_tag_c], dim=1)

        arc = self.dropout_out(arc.transpose(1, 2)).transpose(1, 2)
        arc_h, arc_c = arc.chunk(2, 1)
        arc_tag = self.dropout_out(arc_tag.transpose(1, 2)).transpose(1, 2)

        # output from rnn [batch_size, length, tag_space]
        arc_tag_h, arc_tag_c = arc_tag.chunk(2, 1)
        # head shape [batch_size, length, arc_tag_space]
        arc_tag_h = arc_tag_h.contiguous()
        # child shape [batch_size, length, arc_tag_space]
        arc_tag_c = arc_tag_c.contiguous()
        arc = (arc_h, arc_c)
        # [batch_size, length, length]
        out_arc = self.attention(arc[0], arc[1], mask_d=mask, mask_e=mask).squeeze(dim=1)
        out_arc_tag = (arc_tag_h, arc_tag_c)
        return out_arc, out_arc_tag

    def loss(self, out_arc, out_arc_tag, heads, arc_tags, mask=None, length=None):
        out_arc, out_arc_tag = self.pre_loss(out_arc, out_arc_tag, heads=heads, arc_tags=arc_tags, mask=mask, length=length, use_log=True, temperature=1.0)
        batch_size, max_len = out_arc.size()
        # loss_arc shape [length-1, batch_size]
        out_arc = out_arc.t()
        # loss_arc_tag shape [length-1, batch_size]
        out_arc_tag = out_arc_tag.t()
        # number of valid positions which contribute to loss (remove the symbolic head for each sentence).
        num = mask.sum() - batch_size if mask is not None else float(max_len) * batch_size
        dp_loss = -out_arc.sum() / num, -out_arc_tag.sum() / num
        return dp_loss

    def decode(self, out_arc, out_arc_tag, mask, length, leading_symbolic):
        if self.arc_decode == 'mst':
            heads, arc_tags, scores = self.decode_mst(out_arc, out_arc_tag, mask, length, leading_symbolic)
        else: #self.arc_decode == 'greedy'
            heads, arc_tags, scores = self.decode_greedy(out_arc, out_arc_tag, mask, leading_symbolic)
        return heads, arc_tags, scores

    def decode_mst(self, out_arc, out_arc_tag, mask, length, leading_symbolic):
        loss_arc, loss_arc_tag = self.pre_loss(out_arc, out_arc_tag, heads=None, arc_tags=None, mask=mask, length=length, use_log=True, temperature=1.0)
        batch_size, max_len, _ = loss_arc.size()
        # compute lengths
        if length is None:
            if mask is None:
                length = [max_len for _ in range(batch_size)]
            else:
                length = mask.data.sum(dim=1).long().cpu().numpy()
        # energy shape [batch_size, num_arcs, length, length]
        energy = torch.exp(loss_arc.unsqueeze(1) + loss_arc_tag)
        heads, arc_tags = parse.decode_MST(energy.data.cpu().numpy(), length, leading_symbolic=leading_symbolic,
                                           labeled=True)
        heads = from_numpy(heads)
        arc_tags = from_numpy(arc_tags)

        # compute the average score for each tree
        batch_size, max_len = heads.size()
        scores = torch.zeros_like(heads, dtype=energy.dtype, device=energy.device)
        for b_idx in range(batch_size):
            for len_idx in range(max_len):
                scores[b_idx, len_idx] = energy[b_idx, arc_tags[b_idx, len_idx], heads[b_idx, len_idx], len_idx]
        if mask is not None:
            scores = scores.sum(1) / mask.sum(1)
        else:
            scores = scores.sum(1) / max_len
        scores = scores.detach()
        return heads, arc_tags, scores

    def decode_greedy(self, out_arc, out_arc_tag, mask, leading_symbolic):
        '''
        Args:
            out_arc: Tensor
                the arc scores with shape [batch_size, length, length]
            out_arc_tag: Tensor
                the labeled arc scores with shape [batch_size, length, arc_tag_space]
            mask: Tensor or None
                the mask tensor with shape = [batch_size, length]
            length: Tensor or None
                the length tensor with shape = [batch_size]
            leading_symbolic: int
                number of symbolic labels leading in arc_tag alphabets (set it to 0 if you are not sure)

        Returns: (Tensor, Tensor)
                predicted heads and arc_tags.
        '''
        def _decode_arc_tags(out_arc_tag, heads, leading_symbolic):
            # out_arc_tag shape [batch_size, length, arc_tag_space]
            arc_tag_h, arc_tag_c = out_arc_tag
            batch_size, max_len, _ = arc_tag_h.size()
            # create batch index [batch_size]
            batch_index = torch.arange(0, batch_size).type_as(arc_tag_h.data).long()
            # get vector for heads [batch_size, length, arc_tag_space],
            arc_tag_h = arc_tag_h[batch_index, heads.t()].transpose(0, 1).contiguous()
            # compute output for arc_tag [batch_size, length, num_arcs]
            out_arc_tag = self.bilinear(arc_tag_h, arc_tag_c)
            # remove the first #leading_symbolic arc_tags.
            out_arc_tag = out_arc_tag[:, :, leading_symbolic:]
            # compute the prediction of arc_tags [batch_size, length]
            _, arc_tags = out_arc_tag.max(dim=2)
            return arc_tags + leading_symbolic

        # out_arc shape [batch_size, length, length]
        out_arc = out_arc.data
        _, max_len, _ = out_arc.size()
        # set diagonal elements to -inf
        out_arc = out_arc + torch.diag(out_arc.new(max_len).fill_(-np.inf))
        # set invalid positions to -inf
        if mask is not None:
            # minus_mask = (1 - mask.data).byte().view(batch_size, max_len, 1)
            minus_mask = (1 - mask.data).byte().unsqueeze(2)
            out_arc.masked_fill_(minus_mask, -np.inf)

        # compute naive predictions.
        # prediction shape = [batch_size, length]
        scores, heads = out_arc.max(dim=1)

        arc_tags = _decode_arc_tags(out_arc_tag, heads, leading_symbolic)
        # compute the average score for each tree
        if mask is not None:
            scores = scores.sum(1) / mask.sum(1)
        else:
            scores = scores.sum(1) / max_len
        return heads, arc_tags, scores

    def pre_loss(self, out_arc, out_arc_tag, heads=None, arc_tags=None, mask=None, length=None, use_log=True, temperature=1.0):
        if (heads is not None and arc_tags is None) or (heads is None and arc_tags is not None):
            raise ValueError('heads and arc_tags should be both Nones or both not Nones')
        decode = True if (heads is None and arc_tags is None) else False
        softmax_func = F.log_softmax if use_log else F.softmax
        # out_arc shape [batch_size, length, length]
        # out_arc_tag shape [batch_size, length, arc_tag_space]
        arc_tag_h, arc_tag_c = out_arc_tag
        batch_size, max_len, arc_tag_space = arc_tag_h.size()
        batch_index = None
        if not decode:
            if length is not None and heads.size(1) != max_len:
                heads = heads[:, :max_len]
                arc_tags = arc_tags[:, :max_len]
            # create batch index [batch_size]
            batch_index = torch.arange(0, batch_size).type_as(out_arc.data).long()
            # get vector for heads [batch_size, length, arc_tag_space],
            arc_tag_h = arc_tag_h[batch_index, heads.data.t()].transpose(0, 1).contiguous()
        else:
            arc_tag_h = arc_tag_h.unsqueeze(2).expand(batch_size, max_len, max_len, arc_tag_space).contiguous()
            arc_tag_c = arc_tag_c.unsqueeze(1).expand(batch_size, max_len, max_len, arc_tag_space).contiguous()

        # compute output for arc_tag [batch_size, length, num_arcs]
        out_arc_tag = self.bilinear(arc_tag_h, arc_tag_c)

        # mask invalid position to -inf for softmax_func
        if mask is not None:
            minus_inf = -1e8
            minus_mask = (1 - mask) * minus_inf
            out_arc = out_arc + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

        if not decode:
            # loss_arc shape [batch_size, length, length]
            out_arc = softmax_func(out_arc / temperature, dim=1)
            # loss_arc_tag shape [batch_size, length, num_arcs]
            out_arc_tag = softmax_func(out_arc_tag / temperature, dim=2)
            # mask invalid position to 0 for sum loss
            if mask is not None:
                out_arc = out_arc * mask.unsqueeze(2) * mask.unsqueeze(1)
                out_arc_tag = out_arc_tag * mask.unsqueeze(2)

            # first create index matrix [length, batch_size]
            child_index = torch.arange(0, max_len).view(max_len, 1).expand(max_len, batch_size)
            child_index = child_index.type_as(out_arc.data).long()
            # loss_arc shape [batch_size, length-1]
            out_arc = out_arc[batch_index, heads.data.t(), child_index][1:].t()
            # loss_arc_tag shape [batch_size, length-1]
            out_arc_tag = out_arc_tag[batch_index, child_index, arc_tags.data.t()][1:].t()
        else:
            # loss_arc shape [batch_size, length, length]
            out_arc = softmax_func(out_arc / temperature, dim=1)
            # loss_arc_tag shape [batch_size, length, length, num_arcs]
            out_arc_tag = softmax_func(out_arc_tag / temperature, dim=3).permute(0, 3, 1, 2)
        return out_arc, out_arc_tag


class BiRecurrentConv_Encoder(nn.Module):
    def __init__(self, word_dim, num_words, char_dim, num_chars, use_pos, use_char, pos_dim, num_pos, num_filters,
                 kernel_size, rnn_mode, hidden_size, num_layers, embedd_word=None, embedd_char=None, embedd_pos=None,
                 p_in=0.33, p_out=0.33, p_rnn=(0.33, 0.33), initializer=None):
        super(BiRecurrentConv_Encoder, self).__init__()
        self.word_embedd = Embedding(num_words, word_dim, init_embedding=embedd_word)
        self.char_embedd = Embedding(num_chars, char_dim, init_embedding=embedd_char) if use_char else None
        self.pos_embedd = Embedding(num_pos, pos_dim, init_embedding=embedd_pos) if use_pos else None
        self.conv1d = nn.Conv1d(char_dim, num_filters, kernel_size, padding=kernel_size - 1) if use_char else None
        # dropout word
        self.dropout_in = nn.Dropout2d(p_in)
        # standard dropout
        self.dropout_out = nn.Dropout2d(p_out)
        self.dropout_rnn_in = nn.Dropout(p_rnn[0])
        self.use_pos = use_pos
        self.use_char = use_char
        self.rnn_mode = rnn_mode
        self.dim_enc = word_dim
        if use_pos:
            self.dim_enc += pos_dim
        if use_char:
            self.dim_enc += num_filters

        if rnn_mode == 'RNN':
            RNN = nn.RNN
            drop_p_rnn = p_rnn[1]
        elif rnn_mode == 'LSTM':
            RNN = nn.LSTM
            drop_p_rnn = p_rnn[1]
        elif rnn_mode == 'GRU':
            RNN = nn.GRU
            drop_p_rnn = p_rnn[1]
        else:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)
        self.rnn = RNN(self.dim_enc, hidden_size, num_layers=num_layers, batch_first=True, bidirectional=True,
                       dropout=drop_p_rnn)
        self.initializer = initializer
        self.reset_parameters()

    def reset_parameters(self):
        if self.initializer is None:
            return

        for name, parameter in self.named_parameters():
            if name.find('embedd') == -1:
                if parameter.dim() == 1:
                    parameter.data.zero_()
                else:
                    self.initializer(parameter.data)

    def forward(self, input_word, input_char, input_pos, mask=None, length=None, hx=None):
        # hack length from mask
        # we do not hack mask from length for special reasons.
        # Thus, always provide mask if it is necessary.
        if length is None and mask is not None:
            length = mask.data.sum(dim=1).long()

        # [batch_size, length, word_dim]
        word = self.word_embedd(input_word)
        # apply dropout on input
        word = self.dropout_in(word)

        input = word
        if self.use_char:
            # [batch_size, length, char_length, char_dim]
            char = self.char_embedd(input_char)
            char_size = char.size()
            # first transform to [batch *length, char_length, char_dim]
            # then transpose to [batch * length, char_dim, char_length]
            char = char.view(char_size[0] * char_size[1], char_size[2], char_size[3]).transpose(1, 2)
            # put into cnn [batch*length, char_filters, char_length]
            # then put into maxpooling [batch * length, char_filters]
            char, _ = self.conv1d(char).max(dim=2)
            # reshape to [batch_size, length, char_filters]
            char = torch.tanh(char).view(char_size[0], char_size[1], -1)
            # apply dropout on input
            char = self.dropout_in(char)
            # concatenate word and char [batch_size, length, word_dim+char_filter]
            input = torch.cat([input, char], dim=2)

        if self.use_pos:
            # [batch_size, length, pos_dim]
            pos = self.pos_embedd(input_pos)
            # apply dropout on input
            pos = self.dropout_in(pos)
            input = torch.cat([input, pos], dim=2)

        # apply dropout rnn input
        input = self.dropout_rnn_in(input)
        # prepare packed_sequence
        if length is not None:
            seq_input, hx, rev_order, mask = utils.prepare_rnn_seq(input, length, hx=hx, masks=mask, batch_first=True)
            self.rnn.flatten_parameters()
            seq_output, hn = self.rnn(seq_input, hx=hx)
            output, hn = utils.recover_rnn_seq(seq_output, rev_order, hx=hn, batch_first=True)
        else:
            # output from rnn [batch_size, length, hidden_size]
            self.rnn.flatten_parameters()
            output, hn = self.rnn(input, hx=hx)
        # apply dropout for the output of rnn
        output = self.dropout_out(output)
        return output, hn, mask, length