omarmomen commited on
Commit
9f88c8a
·
verified ·
1 Parent(s): 9b430d5

Update structformer.py

Browse files
Files changed (1) hide show
  1. structformer.py +253 -3
structformer.py CHANGED
@@ -18,12 +18,262 @@
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
-
22
- import layers
23
-
24
  from transformers import PretrainedConfig, PreTrainedModel
25
  from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def cumprod(x, reverse=False, exclusive=False):
29
  """cumulative product."""
 
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
+ from torch.nn import init
 
 
22
  from transformers import PretrainedConfig, PreTrainedModel
23
  from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
24
 
25
+ def _get_activation_fn(activation):
26
+ """Get specified activation function."""
27
+ if activation == "relu":
28
+ return nn.ReLU()
29
+ elif activation == "gelu":
30
+ return nn.GELU()
31
+ elif activation == "leakyrelu":
32
+ return nn.LeakyReLU()
33
+
34
+ raise RuntimeError(
35
+ "activation should be relu/gelu, not {}".format(activation))
36
+
37
+
38
class Conv1d(nn.Module):
    """1D convolution applied along the sequence dimension of embeddings."""

    def __init__(self, hidden_size, kernel_size, dilation=1):
        """Initialization.

        Args:
            hidden_size: dimension of input embeddings
            kernel_size: convolution kernel size
            dilation: the spacing between the kernel points
        """
        super(Conv1d, self).__init__()

        # An even kernel with this padding yields one extra output position;
        # remember that so forward() can drop the leading frame.
        self.shift = (kernel_size % 2 == 0)
        if self.shift:
            padding = (kernel_size // 2) * dilation
        else:
            padding = ((kernel_size - 1) // 2) * dilation
        self.conv = nn.Conv1d(
            hidden_size,
            hidden_size,
            kernel_size,
            padding=padding,
            dilation=dilation)

    def forward(self, x):
        """Compute convolution.

        Args:
            x: input embeddings — presumably (batch, length, hidden_size);
                confirm against callers.
        Returns:
            conv_output: convolution results, same shape as x
        """
        # nn.Conv1d wants (batch, channels, length), so transpose in and out.
        out = self.conv(x.transpose(1, 2)).transpose(1, 2)
        return out[:, 1:] if self.shift else out
77
+
78
+
79
class MultiheadAttention(nn.Module):
    """Multi-head self-attention layer with optional relative-position bias."""

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 v_proj=True,
                 out_proj=True,
                 relative_bias=True):
        """Initialization.

        Args:
            embed_dim: dimension of input embeddings
            num_heads: number of self-attention heads
            dropout: dropout rate applied to attention weights
            bias: bool, indicate whether include bias for linear transformations
            v_proj: bool, indicate whether project inputs to new values
            out_proj: bool, indicate whether project outputs to new values
            relative_bias: bool, indicate whether use a relative position based
                attention bias
        """
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.drop = nn.Dropout(dropout)
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, ("embed_dim must be "
                                                             "divisible by "
                                                             "num_heads")

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        if v_proj:
            self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        else:
            self.v_proj = nn.Identity()

        if out_proj:
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        else:
            self.out_proj = nn.Identity()

        if relative_bias:
            # One learned bias per head per relative offset; the table has 512
            # slots centered at offset 256, so forward() indexes it safely only
            # for sequence lengths up to 256 — longer inputs would go out of
            # range (NOTE(review): limit inherited from original code).
            self.relative_bias = nn.Parameter(torch.zeros((self.num_heads, 512)))
        else:
            self.relative_bias = None

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize attention parameters.

        Bug fix: with ``bias=False`` the Linear layers have ``bias is None``;
        the original unconditionally called ``init.constant_`` on it and
        crashed during construction. Bias init is now guarded.
        """
        for proj in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            if isinstance(proj, nn.Linear):
                init.xavier_uniform_(proj.weight)
                if proj.bias is not None:
                    init.constant_(proj.bias, 0.)

    def forward(self, query, key_padding_mask=None, attn_mask=None):
        """Compute multi-head self-attention.

        Args:
            query: input embeddings of shape (length, batch, embed_dim)
            key_padding_mask: additive mask that prevents attention to certain
                positions (added to the attention logits)
            attn_mask: 3D mask that rescales the attention weight at each
                position; when given, sigmoid gating replaces softmax
        Returns:
            attn_output: self-attention output, same shape as query
        """
        length, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        head_dim = embed_dim // self.num_heads
        assert head_dim * self.num_heads == embed_dim, ("embed_dim must be "
                                                        "divisible by num_heads")
        # Standard 1/sqrt(d_head) scaling of the dot products.
        scaling = float(head_dim)**-0.5

        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)

        q = q * scaling

        if attn_mask is not None:
            assert list(attn_mask.size()) == [bsz * self.num_heads,
                                              query.size(0), query.size(0)]

        # Fold heads into the batch dimension: (bsz*num_heads, length, head_dim).
        q = q.contiguous().view(length, bsz * self.num_heads,
                                head_dim).transpose(0, 1)
        k = k.contiguous().view(length, bsz * self.num_heads,
                                head_dim).transpose(0, 1)
        v = v.contiguous().view(length, bsz * self.num_heads,
                                head_dim).transpose(0, 1)

        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(
            attn_output_weights.size()) == [bsz * self.num_heads, length, length]

        if self.relative_bias is not None:
            # |i - j| + 256 indexes the per-head bias table; gather broadcasts
            # the per-distance bias to every (query, key) pair.
            pos = torch.arange(length, device=query.device)
            relative_pos = torch.abs(pos[:, None] - pos[None, :]) + 256
            relative_pos = relative_pos[None, :, :].expand(bsz * self.num_heads,
                                                           -1, -1)

            relative_bias = self.relative_bias.repeat_interleave(bsz, dim=0)
            relative_bias = relative_bias[:, None, :].expand(-1, length, -1)
            relative_bias = torch.gather(relative_bias, 2, relative_pos)
            attn_output_weights = attn_output_weights + relative_bias

        if key_padding_mask is not None:
            # Additive mask (e.g. large negative values at padded keys).
            attn_output_weights = attn_output_weights + key_padding_mask

        if attn_mask is None:
            attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
        else:
            # StructFormer-style gating: sigmoid scores rescaled by the mask
            # instead of a softmax distribution.
            attn_output_weights = torch.sigmoid(attn_output_weights) * attn_mask

        attn_output_weights = self.drop(attn_output_weights)

        attn_output = torch.bmm(attn_output_weights, v)

        assert list(attn_output.size()) == [bsz * self.num_heads, length, head_dim]
        attn_output = attn_output.transpose(0, 1).contiguous().view(
            length, bsz, embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output
218
+
219
+
220
class TransformerLayer(nn.Module):
    """Pre-norm transformer encoder layer: self-attention then feedforward."""

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 dropatt=0.1,
                 activation="leakyrelu",
                 relative_bias=True):
        """Initialization.

        Args:
            d_model: dimension of inputs
            nhead: number of self-attention heads
            dim_feedforward: dimension of hidden layer in feedforward layer
            dropout: dropout rate
            dropatt: drop attention rate
            activation: activation function
            relative_bias: bool, indicate whether use a relative position based
                attention bias
        """
        super(TransformerLayer, self).__init__()

        # Module creation order matches the original so seeded parameter
        # initialization is reproducible.
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=dropatt, relative_bias=relative_bias)
        # Position-wise feedforward sub-layer with its own pre-LayerNorm.
        self.feedforward = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, dim_feedforward),
            _get_activation_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model))

        self.norm = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.nhead = nhead

    def forward(self, src, attn_mask=None, key_padding_mask=None):
        """Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            attn_mask: the mask for the src sequence (optional).
            key_padding_mask: the mask for the src keys per batch (optional).
        Returns:
            Output of the layer; same shape as src.
        """
        # Attention sub-layer (pre-norm) with residual connection.
        attn_out = self.self_attn(
            self.norm(src), attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)
        hidden = src + self.dropout1(attn_out)
        # Feedforward sub-layer with residual connection.
        ff_out = self.feedforward(hidden)
        return hidden + self.dropout2(ff_out)
276
+
277
 
278
  def cumprod(x, reverse=False, exclusive=False):
279
  """cumulative product."""