Tingxie commited on
Commit
c8bfe50
·
1 Parent(s): 09db96e

Upload 10 files

Browse files
nn_utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .nn_utils import *
2
+ from .transformer_layers import *
nn_utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (202 Bytes). View file
 
nn_utils/__pycache__/form_embedder.cpython-38.pyc ADDED
Binary file (10.1 kB). View file
 
nn_utils/__pycache__/nn_utils.cpython-38.pyc ADDED
Binary file (2.87 kB). View file
 
nn_utils/__pycache__/transformer_layers.cpython-38.pyc ADDED
Binary file (20.2 kB). View file
 
nn_utils/base_hyperopt.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ base_hyperopt.py
2
+
3
+ Abstract away common hyperopt functionality
4
+
5
+ """
6
+ import logging
7
+ import yaml
8
+ from pathlib import Path
9
+ from datetime import datetime
10
+ from typing import Callable
11
+
12
+ import pytorch_lightning as pl
13
+
14
+ import ray
15
+ from ray import tune
16
+ from ray.air.config import RunConfig
17
+ from ray.tune.search import ConcurrencyLimiter
18
+ from ray.tune.search.optuna import OptunaSearch
19
+ from ray.tune.schedulers.async_hyperband import ASHAScheduler
20
+
21
+ import mist_cf.common as common
22
+
23
+
24
def add_hyperopt_args(parser):
    """Attach Ray Tune hyperopt options to *parser*.

    Adds a "Hyperopt Args" argument group and overrides the parser's
    ``save_dir`` default with a date-stamped hyperopt results folder.

    Args:
        parser: An ``argparse.ArgumentParser`` to extend in place.
    """
    group = parser.add_argument_group("Hyperopt Args")
    group.add_argument("--cpus-per-trial", type=int, default=1)
    group.add_argument("--gpus-per-trial", type=float, default=1)
    group.add_argument("--num-h-samples", type=int, default=50)
    # Grace period is measured in seconds (15 minutes).
    group.add_argument("--grace-period", type=int, default=60 * 15)
    group.add_argument("--max-concurrent", type=int, default=10)
    group.add_argument("--tune-checkpoint", default=None)

    # Point the default save dir at a fresh, date-stamped hyperopt folder.
    stamp = datetime.now().strftime("%Y_%m_%d")
    parser.set_defaults(save_dir=f"results/{stamp}_hyperopt/")
38
+
39
+
40
def run_hyperopt(
    kwargs: dict,
    score_function: Callable,
    param_space_function: Callable,
    initial_points: list,
    gen_shared_data: Callable = lambda params: {},
):
    """run_hyperopt.

    Run an Optuna-driven Ray Tune search with ASHA early stopping.

    Args:
        kwargs: All dictionary args for hyperopt and train
        score_function: Trainable function that sets up model train
        param_space_function: Function to suggest new params
        initial_points: List of initial params to try
        gen_shared_data: Callable mapping ``kwargs`` to extra keyword args
            shared by every trial (default: no shared data)

    Side effects:
        Writes ``args.yaml``, ``best_trial.yaml``, ``hyperopt.log`` and
        ``full_res_tbl.tsv`` into ``kwargs["save_dir"]``.
    """
    # init ray with new session
    ray.init(address="local")

    # Fix base_args based upon tune args: trials see gpu=True whenever a
    # (possibly fractional) GPU is allocated per trial.
    kwargs["gpu"] = kwargs.get("gpus_per_trial", 0) > 0
    # max_t = args.max_epochs

    # Debug mode shrinks the search so a full loop finishes quickly.
    if kwargs["debug"]:
        kwargs["num_h_samples"] = 10
        kwargs["max_epochs"] = 5

    save_dir = kwargs["save_dir"]
    common.setup_logger(
        save_dir, log_name="hyperopt.log", debug=kwargs.get("debug", False)
    )
    # NOTE(review): `pl.utilities.seed.seed_everything` is deprecated in newer
    # Lightning releases; `pl.seed_everything` is the stable entry point —
    # verify against the pinned pytorch_lightning version.
    pl.utilities.seed.seed_everything(kwargs.get("seed"))

    # Data/objects shared across all trials (built once, broadcast by Tune).
    shared_args = gen_shared_data(kwargs)

    # Define score function; orig_dir lets trials resolve paths relative to
    # the launch directory (Tune changes each trial's working directory).
    trainable = tune.with_parameters(
        score_function, base_args=kwargs, orig_dir=Path().resolve(), **shared_args
    )

    # Dump args
    yaml_args = yaml.dump(kwargs)
    logging.info(f"\n{yaml_args}")
    with open(Path(save_dir) / "args.yaml", "w") as fp:
        fp.write(yaml_args)

    # Optimization target reported by the trainable.
    metric = "val_loss"

    # Include cpus and gpus per trial. Two bundles per trial: one for the
    # trainer process (CPU+GPU) and one for dataloader workers; PACK keeps
    # both bundles on the same node.
    trainable = tune.with_resources(
        trainable,
        resources=tune.PlacementGroupFactory(
            [
                {
                    "CPU": kwargs.get("cpus_per_trial"),
                    "GPU": kwargs.get("gpus_per_trial"),
                },
                {
                    "CPU": kwargs.get("num_workers"),
                },
            ],
            strategy="PACK",
        ),
    )

    # Optuna proposes configs, seeded with user-supplied initial points.
    search_algo = OptunaSearch(
        metric=metric,
        mode="min",
        points_to_evaluate=initial_points,
        space=param_space_function,
    )
    search_algo = ConcurrencyLimiter(
        search_algo, max_concurrent=kwargs["max_concurrent"]
    )

    tuner = tune.Tuner(
        trainable,
        tune_config=tune.TuneConfig(
            mode="min",
            metric=metric,
            search_alg=search_algo,
            # ASHA prunes on wall-clock time: each trial gets at most 24h and
            # at least `grace_period` seconds before it can be stopped.
            scheduler=ASHAScheduler(
                max_t=24 * 60 * 60,  # max_t,
                time_attr="time_total_s",
                grace_period=kwargs.get("grace_period"),
                reduction_factor=2,
            ),
            num_samples=kwargs.get("num_h_samples"),
        ),
        run_config=RunConfig(name=None, local_dir=kwargs["save_dir"]),
    )

    # NOTE(review): in recent Ray versions `Tuner.restore` is a classmethod;
    # calling it on the instance works but discards the freshly-built tuner
    # config above — confirm against the pinned ray version.
    if kwargs.get("tune_checkpoint") is not None:
        ckpt = str(Path(kwargs["tune_checkpoint"]).resolve())
        tuner = tuner.restore(path=ckpt, restart_errored=True)

    results = tuner.fit()
    best_trial = results.get_best_result()
    output = {"score": best_trial.metrics[metric], "config": best_trial.config}
    out_str = yaml.dump(output, indent=2)
    logging.info(out_str)
    with open(Path(save_dir) / "best_trial.yaml", "w") as f:
        f.write(out_str)

    # Output full res table
    # NOTE(review): `index=None` relies on pandas treating None as falsy;
    # `index=False` is the documented way to drop the index column.
    results.get_dataframe().to_csv(
        Path(save_dir) / "full_res_tbl.tsv", sep="\t", index=None
    )
nn_utils/form_embedder.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+ import mist_cf.common as common
6
+
7
+
8
class IntFeaturizer(nn.Module):
    """Map integer counts to vector features for downstream networks.

    Subclasses must define ``self.int_to_feat_matrix``: row ``i`` holds the
    feature vector for integer ``i`` (``i`` in ``[0, MAX_COUNT_INT)``), so
    ``self.int_to_feat_matrix[5]`` is the representation of ``5``.

    On top of the integer table, this base class owns
    ``NUM_EXTRA_EMBEDDINGS`` learned vectors for non-integer tokens such as
    the "to be confirmed" (pad) token. They are concatenated conceptually
    after the integer embeddings and addressed by indices starting at
    ``MAX_COUNT_INT``.
    """

    # Largest count handled by the table; valid ints are 0..MAX_COUNT_INT-1.
    MAX_COUNT_INT = 255
    # Learned, non-integer embeddings (one: the "to be confirmed" token).
    NUM_EXTRA_EMBEDDINGS = 1

    def __init__(self, embedding_dim):
        super().__init__()
        extra = torch.zeros(self.NUM_EXTRA_EMBEDDINGS, embedding_dim)
        self._extra_embeddings = nn.Parameter(extra, requires_grad=True)
        nn.init.normal_(self._extra_embeddings, 0.0, 1.0)
        self.embedding_dim = embedding_dim

    def forward(self, tensor):
        """Embed integer `tensor`, stacking embeddings along the last dim.

        Input of shape (..., k) yields output of shape (..., k * embedding_dim).
        """
        in_shape = tensor.shape
        features = torch.empty(
            (*in_shape, self.embedding_dim), device=tensor.device
        )

        # Indices >= MAX_COUNT_INT address the learned extra embeddings.
        is_extra = tensor >= self.MAX_COUNT_INT
        idx = tensor.long()

        features[~is_extra] = self.int_to_feat_matrix[idx[~is_extra]]
        features[is_extra] = self._extra_embeddings[idx[is_extra] - self.MAX_COUNT_INT]

        # Flatten the per-entry embeddings into the final dimension.
        return features.reshape(*in_shape[:-1], -1)

    @property
    def num_dim(self):
        """Dimension of a single integer's feature vector."""
        return self.int_to_feat_matrix.shape[1]

    @property
    def full_dim(self):
        """Flattened feature size: one embedding per entry of common.NORM_VEC."""
        return self.num_dim * common.NORM_VEC.shape[0]
63
+
64
+
65
class Binarizer(IntFeaturizer):
    """Featurize integer counts by their fixed binary code (non-trainable)."""

    def __init__(self):
        super().__init__(embedding_dim=len(common.num_to_binary(0)))
        # Precompute the binary code for every representable count.
        binary_rows = [common.num_to_binary(i) for i in range(self.MAX_COUNT_INT)]
        feats = torch.from_numpy(np.vstack(binary_rows)).float()
        self.int_to_feat_matrix = nn.Parameter(feats, requires_grad=False)
74
+
75
+
76
class FourierFeaturizer(IntFeaturizer):
    """Fixed Fourier features (cosine + sine channels) for integer counts.

    Inspired by Tancik et al. (2020), "Fourier Features Let Networks Learn
    High Frequency Functions in Low Dimensional Domains"
    (http://arxiv.org/abs/2006.10739).

    Frequencies sit at powers of 1/2 rather than random Gaussian samples, so
    the featurization closely matches the Binarizer but is smoother.
    """

    def __init__(self):
        # Enough octaves that the whole input range fits on the half circle.
        num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2
        freqs_times_2pi = 2 * np.pi * (
            0.5 ** torch.arange(num_freqs, dtype=torch.float32)
        )

        # One cosine and one sine channel per frequency.
        super().__init__(embedding_dim=2 * freqs_times_2pi.shape[0])

        # Precompute features up front: only a fixed number of counts exists.
        angles = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * freqs_times_2pi[None, :]
        )
        feats = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
        # shape: MAX_COUNT_INT x 2 * num_freqs, frozen (not trained).
        self.int_to_feat_matrix = nn.Parameter(feats.float(), requires_grad=False)
113
+
114
+
115
class FourierFeaturizerSines(IntFeaturizer):
    """Fourier features for integer counts using sine channels only.

    Inspired by Tancik et al. (2020), "Fourier Features Let Networks Learn
    High Frequency Functions in Low Dimensional Domains"
    (http://arxiv.org/abs/2006.10739).

    Frequencies sit at powers of 1/2 rather than random Gaussian samples, so
    the featurization closely matches the Binarizer but is smoother.
    """

    def __init__(self):
        num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2
        # Drop the two highest frequencies; keep only sine components.
        kept_freqs = (0.5 ** torch.arange(num_freqs, dtype=torch.float32))[2:]
        freqs_times_2pi = 2 * np.pi * kept_freqs

        super().__init__(embedding_dim=freqs_times_2pi.shape[0])

        # Precompute features up front: only a fixed number of counts exists.
        angles = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * freqs_times_2pi[None, :]
        )
        # shape: MAX_COUNT_INT x (num_freqs - 2), frozen (not trained).
        self.int_to_feat_matrix = nn.Parameter(
            torch.sin(angles).float(), requires_grad=False
        )
149
+
150
+
151
class FourierFeaturizerAbsoluteSines(IntFeaturizer):
    """Fourier features for integer counts: absolute-valued sines only.

    Inspired by Tancik et al. (2020), "Fourier Features Let Networks Learn
    High Frequency Functions in Low Dimensional Domains"
    (http://arxiv.org/abs/2006.10739).

    Frequencies sit at powers of 1/2 rather than random Gaussian samples, so
    the featurization closely matches the Binarizer but is smoother.
    """

    def __init__(self):
        num_freqs = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2
        # Drop the two highest frequencies; keep only sine components.
        kept_freqs = (0.5 ** torch.arange(num_freqs, dtype=torch.float32))[2:]
        freqs_times_2pi = 2 * np.pi * kept_freqs

        super().__init__(embedding_dim=freqs_times_2pi.shape[0])

        # Precompute features up front: only a fixed number of counts exists.
        angles = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * freqs_times_2pi[None, :]
        )
        # shape: MAX_COUNT_INT x (num_freqs - 2), frozen (not trained).
        self.int_to_feat_matrix = nn.Parameter(
            torch.abs(torch.sin(angles)).float(), requires_grad=False
        )
184
+
185
+
186
class RBFFeaturizer(IntFeaturizer):
    """Radial-basis-function features for integer counts.

    Places ``num_funcs`` Gaussians evenly on ``[0, MAX_COUNT_INT - 1]`` with
    width ``(MAX_COUNT_INT - 1) / num_funcs``, so each bump decays to roughly
    0.6 of its peak by the time the next center is reached.
    """

    def __init__(self, num_funcs=32):
        """
        :param num_funcs: number of radial basis functions; their width is
            chosen automatically — see the class docstring.
        """
        super().__init__(embedding_dim=num_funcs)
        width = (self.MAX_COUNT_INT - 1) / num_funcs
        centers = torch.linspace(0, self.MAX_COUNT_INT - 1, num_funcs)

        # Normalized distance of every count to every center.
        scaled_dists = (
            torch.arange(self.MAX_COUNT_INT)[:, None] - centers[None, :]
        ) / width
        # shape: MAX_COUNT_INT x num_funcs, frozen (not trained).
        feats = torch.exp(-0.5 * scaled_dists**2)
        self.int_to_feat_matrix = nn.Parameter(feats.float(), requires_grad=False)
212
+
213
+
214
class OneHotFeaturizer(IntFeaturizer):
    """One-hot encoding of integer counts.

    Represents:
        - 0 as 1000000000...
        - 1 as 0100000000...
        - 2 as 0010000000...
    and so on.
    """

    def __init__(self):
        super().__init__(embedding_dim=self.MAX_COUNT_INT)
        # Identity matrix == one-hot lookup table; frozen (not trained).
        self.int_to_feat_matrix = nn.Parameter(
            torch.eye(self.MAX_COUNT_INT).float(), requires_grad=False
        )
230
+
231
+
232
class LearnedFeaturizer(IntFeaturizer):
    """Trainable embedding table for integer counts.

    Essentially ``nn.Embedding``, but routed through the superclass forward,
    which stacks embeddings along the final dimension.
    """

    def __init__(self, feature_dim=32):
        super().__init__(embedding_dim=feature_dim)
        table = torch.zeros(self.MAX_COUNT_INT, feature_dim)
        self.int_to_feat_matrix = nn.Parameter(table, requires_grad=True)
        nn.init.normal_(self.int_to_feat_matrix, 0.0, 1.0)
244
+
245
+
246
class FloatFeaturizer(IntFeaturizer):
    """Scale raw counts by the dataset norm vector instead of embedding them."""

    def __init__(self):
        # embedding_dim=1 is a placeholder: each count stays a single scalar.
        super().__init__(embedding_dim=1)
        norm_vec = torch.from_numpy(common.NORM_VEC).float()
        # Frozen buffer-like parameter holding the per-position norms.
        self.norm_vec = nn.Parameter(norm_vec, requires_grad=False)

    def forward(self, tensor):
        """Divide `tensor` elementwise by norm_vec, broadcast over the last dim."""
        lead_dims = len(tensor.shape) - 1
        scale = self.norm_vec.reshape(*([1] * lead_dims), -1)
        return tensor / scale

    @property
    def num_dim(self):
        """Each count maps to exactly one scalar feature."""
        return 1
270
+
271
+
272
def get_embedder(embedder):
    """Instantiate an integer featurizer by its CLI name.

    Args:
        embedder: One of "binary", "fourier", "rbf", "one-hot", "learnt",
            "float", "fourier-sines", or "abs-sines".

    Returns:
        A freshly constructed featurizer instance.

    Raises:
        NotImplementedError: If the name is not recognized.
    """
    registry = {
        "binary": Binarizer,
        "fourier": FourierFeaturizer,
        "rbf": RBFFeaturizer,
        "one-hot": OneHotFeaturizer,
        "learnt": LearnedFeaturizer,
        "float": FloatFeaturizer,
        "fourier-sines": FourierFeaturizerSines,
        "abs-sines": FourierFeaturizerAbsoluteSines,
    }
    if embedder not in registry:
        raise NotImplementedError
    return registry[embedder]()
nn_utils/gitkeep.txt ADDED
File without changes
nn_utils/nn_utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ nn_utils.py
2
+ """
3
+ import math
4
+ import copy
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
def build_lr_scheduler(
    optimizer, lr_decay_rate: float, decay_steps: int = 5000, warmup: int = 100
):
    """Create a LambdaLR schedule: exponential warmup, then stepwise decay.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        lr_decay_rate (float): Multiplicative decay applied every
            ``decay_steps`` steps after warmup.
        decay_steps (int): Number of steps between decay events.
        warmup (int): Steps of exponential warmup before decay begins.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """

    def lr_lambda(step):
        # During warmup, approach 1 exponentially; afterwards decay in
        # discrete steps of lr_decay_rate every decay_steps steps.
        if step < warmup:
            return 1 - math.exp(-step / warmup)
        return lr_decay_rate ** ((step - warmup) // decay_steps)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
33
+
34
+
35
class MLPBlocks(nn.Module):
    """Residual MLP: input projection plus ``num_layers - 1`` residual layers.

    Each residual layer computes ``relu(dropout(W h)) + h_prev``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        dropout: float,
        num_layers: int,
    ):
        super().__init__()
        self.activation = nn.ReLU()
        self.dropout_layer = nn.Dropout(p=dropout)
        self.input_layer = nn.Linear(input_size, hidden_size)
        hidden_layer = nn.Linear(hidden_size, hidden_size)
        # All residual layers start from identical (deep-copied) weights.
        self.layers = nn.ModuleList(
            [copy.deepcopy(hidden_layer) for _ in range(num_layers - 1)]
        )

    def forward(self, x):
        hidden = self.activation(self.dropout_layer(self.input_layer(x)))
        previous = hidden
        for layer in self.layers:
            hidden = self.activation(self.dropout_layer(layer(hidden))) + previous
            previous = hidden
        return hidden
62
+
63
+
64
def get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of *module*."""
    clones = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(clones)
66
+
67
+
68
def pad_packed_tensor(input, lengths, value):
    """Unpack a row-packed tensor into a padded ``(batch, max_len, ...)`` tensor.

    Args:
        input: Packed tensor of shape ``(sum(lengths), ...)`` with each
            sequence's rows stored back to back.
        lengths: Per-sequence lengths (list or 1-D tensor).
        value: Fill value for the padded positions.

    Returns:
        Tensor of shape ``(len(lengths), max(lengths), ...)`` with the same
        dtype and device as ``input``.
    """
    device = input.device
    if isinstance(lengths, torch.Tensor):
        lengths = lengths.to(device)
    else:
        lengths = torch.tensor(lengths, dtype=torch.int64, device=device)

    batch_size = len(lengths)
    max_len = int(lengths.max())

    # Allocate the padded result and mark which slots hold real entries.
    padded = input.new_full((batch_size, max_len, *input.shape[1:]), value)
    valid = torch.arange(max_len, device=device)[None, :] < lengths[:, None]

    # Boolean assignment walks `padded` in row-major order — sequence by
    # sequence — which matches the packed layout of `input` exactly.
    padded[valid] = input
    return padded
nn_utils/transformer_layers.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """transformer_layer.py
2
+
3
+ Hold pairwise attention enabled transformers
4
+
5
+ """
6
+ import math
7
+ from typing import Optional, Union, Callable, Tuple
8
+
9
+ import torch
10
+ from torch import Tensor
11
+ from torch.nn import functional as F
12
+ from torch.nn import Module, LayerNorm, Linear, Dropout, Parameter
13
+ from torch.nn.init import xavier_uniform_, constant_
14
+
15
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
16
+
17
+
18
class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Unlike ``nn.TransformerEncoderLayer``, ``forward`` here returns a tuple
    ``(output, pairwise_features)`` so pairwise features can be threaded
    through a stack of layers.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer, can be a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        additive_attn: if ``True``, use additive attn instead of scaled dot
            product attention
        pairwise_featurization: if ``True``, pass pairwise features into the
            attention mechanism (forwarded to ``MultiheadAttention``).
    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        additive_attn: bool = False,
        pairwise_featurization: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.pairwise_featurization = pairwise_featurization
        # Custom MultiheadAttention (defined below in this module) supporting
        # additive attention and pairwise featurization.
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            additive_attn=additive_attn,
            pairwise_featurization=self.pairwise_featurization,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = activation

    def __setstate__(self, state):
        # Older pickles may predate the `activation` attribute; restore default.
        if "activation" not in state:
            state["activation"] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(
        self,
        src: Tensor,
        pairwise_features: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            pairwise_features: If set, use this to param pairwise features
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Returns:
            Tuple of ``(output, pairwise_features)``; the pairwise features
            are passed through unchanged so stacked layers can reuse them.

        Shape:
            see the docs in Transformer class.
        """

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = src
        if self.norm_first:
            # Pre-LN: normalize before each sub-block, residual outside.
            x = x + self._sa_block(
                self.norm1(x), pairwise_features, src_key_padding_mask
            )
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-LN: residual first, then normalize.
            x = self.norm1(
                x + self._sa_block(x, pairwise_features, src_key_padding_mask)
            )
            x = self.norm2(x + self._ff_block(x))

        return x, pairwise_features

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        pairwise_features: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:

        ## Apply joint featurizer
        x = self.self_attn(
            x,
            x,
            x,
            key_padding_mask=key_padding_mask,
            pairwise_features=pairwise_features,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
153
+
154
+
155
+ class MultiheadAttention(Module):
156
+ r"""Allows the model to jointly attend to information
157
+ from different representation subspaces as described in the paper:
158
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
159
+
160
+ Multi-Head Attention is defined as:
161
+
162
+ .. math::
163
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
164
+
165
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
166
+
167
+ Args:
168
+ embed_dim: Total dimension of the model.
169
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
170
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
171
+ additive_attn: If true, use additive attention instead of scaled dot
172
+ product attention
173
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
174
+ batch_first: If ``True``, then the input and output tensors are provided
175
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
176
+ pairwsie_featurization: If ``True``, use pairwise featurization on the
177
+ inputs
178
+
179
+ Examples::
180
+
181
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
182
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
183
+ """
184
+
185
+ def __init__(
186
+ self,
187
+ embed_dim,
188
+ num_heads,
189
+ additive_attn=False,
190
+ pairwise_featurization: bool = False,
191
+ dropout=0.0,
192
+ batch_first=False,
193
+ device=None,
194
+ dtype=None,
195
+ ) -> None:
196
+ factory_kwargs = {"device": device, "dtype": dtype}
197
+ super(MultiheadAttention, self).__init__()
198
+
199
+ self.embed_dim = embed_dim
200
+ self.kdim = embed_dim
201
+ self.vdim = embed_dim
202
+ self._qkv_same_embed_dim = True
203
+ self.additive_attn = additive_attn
204
+ self.pairwise_featurization = pairwise_featurization
205
+
206
+ self.num_heads = num_heads
207
+ self.dropout = dropout
208
+ self.batch_first = batch_first
209
+ self.head_dim = embed_dim // num_heads
210
+ assert (
211
+ self.head_dim * num_heads == self.embed_dim
212
+ ), "embed_dim must be divisible by num_heads"
213
+ if self.additive_attn:
214
+ head_1_input = (
215
+ self.head_dim * 3 if self.pairwise_featurization else self.head_dim * 2
216
+ )
217
+ self.attn_weight_1_weight = Parameter(
218
+ torch.empty(
219
+ (self.num_heads, head_1_input, self.head_dim), **factory_kwargs
220
+ ),
221
+ )
222
+ self.attn_weight_1_bias = Parameter(
223
+ torch.empty((self.num_heads, self.head_dim), **factory_kwargs),
224
+ )
225
+
226
+ self.attn_weight_2_weight = Parameter(
227
+ torch.empty((self.num_heads, self.head_dim, 1), **factory_kwargs),
228
+ )
229
+ self.attn_weight_2_bias = Parameter(
230
+ torch.empty((self.num_heads, 1), **factory_kwargs),
231
+ )
232
+ # self.attn_weight_1 = Linear(head_1_input, self.head_dim)
233
+ # self.attn_weight_2 = Linear(self.head_dim, 1)
234
+ else:
235
+ if self.pairwise_featurization:
236
+ ## Bias term u
237
+ ##
238
+ self.bias_u = Parameter(
239
+ torch.empty((self.num_heads, self.head_dim), **factory_kwargs),
240
+ )
241
+ self.bias_v = Parameter(
242
+ torch.empty((self.num_heads, self.head_dim), **factory_kwargs),
243
+ )
244
+
245
+ self.in_proj_weight = Parameter(
246
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
247
+ )
248
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
249
+ self.out_proj = NonDynamicallyQuantizableLinear(
250
+ embed_dim, embed_dim, bias=True, **factory_kwargs
251
+ )
252
+
253
+ self._reset_parameters()
254
+
255
+ def _reset_parameters(self):
256
+ """_reset_parameters."""
257
+ xavier_uniform_(self.in_proj_weight)
258
+ constant_(self.in_proj_bias, 0.0)
259
+ constant_(self.out_proj.bias, 0.0)
260
+ if self.additive_attn:
261
+ xavier_uniform_(self.attn_weight_1_weight)
262
+ xavier_uniform_(self.attn_weight_2_weight)
263
+ constant_(self.attn_weight_1_bias, 0.0)
264
+ constant_(self.attn_weight_2_bias, 0.0)
265
+ else:
266
+ if self.pairwise_featurization:
267
+ constant_(self.bias_u, 0.0)
268
+ constant_(self.bias_v, 0.0)
269
+
270
+ def forward(
271
+ self,
272
+ query: Tensor,
273
+ key: Tensor,
274
+ value: Tensor,
275
+ key_padding_mask: Optional[Tensor] = None,
276
+ pairwise_features: Optional[Tensor] = None,
277
+ ) -> Tuple[Tensor, Optional[Tensor]]:
278
+ r"""
279
+ Args:
280
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
281
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
282
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
283
+ Queries are compared against key-value pairs to produce the output.
284
+ See "Attention Is All You Need" for more details.
285
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
286
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
287
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
288
+ See "Attention Is All You Need" for more details.
289
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
290
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
291
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
292
+ See "Attention Is All You Need" for more details.
293
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
294
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
295
+ Binary and byte masks are supported.
296
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
297
+ the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key``
298
+ value will be ignored.
299
+ pairwise_features: If specified, use this in the attention mechanism.
300
+ Handled differently for scalar dot product and additive attn
301
+
302
+ Outputs:
303
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
304
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
305
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
306
+ embedding dimension ``embed_dim``.
307
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
308
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
309
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
310
+ :math:`S` is the source sequence length. If ``average_weights=False``, returns attention weights per
311
+ head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
312
+
313
+ .. note::
314
+ `batch_first` argument is ignored for unbatched inputs.
315
+ """
316
+ is_batched = query.dim() == 3
317
+ if self.batch_first and is_batched:
318
+ query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
319
+
320
+ ## Here!
321
+ attn_output, attn_output_weights = self.multi_head_attention_forward(
322
+ query,
323
+ key,
324
+ value,
325
+ self.embed_dim,
326
+ self.num_heads,
327
+ self.in_proj_weight,
328
+ self.in_proj_bias,
329
+ self.dropout,
330
+ self.out_proj.weight,
331
+ self.out_proj.bias,
332
+ training=self.training,
333
+ key_padding_mask=key_padding_mask,
334
+ pairwise_features=pairwise_features,
335
+ )
336
+
337
+ if self.batch_first and is_batched:
338
+ return attn_output.transpose(1, 0), attn_output_weights
339
+ else:
340
+ return attn_output, attn_output_weights
341
+
342
    def multi_head_attention_forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        embed_dim_to_check: int,
        num_heads: int,
        in_proj_weight: Tensor,
        in_proj_bias: Optional[Tensor],
        dropout_p: float,
        out_proj_weight: Tensor,
        out_proj_bias: Optional[Tensor],
        training: bool = True,
        key_padding_mask: Optional[Tensor] = None,
        pairwise_features: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Core attention computation (sequence-first layout).

        Projects q/k/v, reshapes to per-head batches, builds the padding mask,
        dispatches to additive or scaled dot-product attention, and applies the
        output projection.

        Args:
            query, key, value: map a query and a set of key-value pairs to an
                output. See "Attention Is All You Need" for more details.
            embed_dim_to_check: total dimension of the model.
            num_heads: parallel attention heads.
            in_proj_weight, in_proj_bias: input projection weight and bias.
            dropout_p: probability of an element to be zeroed.
            out_proj_weight, out_proj_bias: the output projection weight and bias.
            training: apply dropout if ``True``.
            key_padding_mask: if provided, specified padding elements in the key
                will be ignored by the attention. This is a binary mask: where
                ``True``, the corresponding attention logit is masked out.
            pairwise_features: if provided, include this in the MHA.
                Expected layout ``(B, L, L, embed_dim)`` given the reshape
                below — assumes self-attention (tgt_len == src_len);
                TODO confirm against callers.
        Shape:
            Inputs:
            - query: :math:`(L, N, E)`; key/value: :math:`(S, N, E)`.
            - key_padding_mask: :math:`(N, S)`.
            Outputs:
            - attn_output: :math:`(L, N, E)`.
            - attn_output_weights: :math:`(N, num_heads, L, S)`.
        """

        # set up shape vars
        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape
        assert (
            embed_dim == embed_dim_to_check
        ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
        if isinstance(embed_dim, torch.Tensor):
            # embed_dim can be a tensor when JIT tracing
            head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
        else:
            head_dim = embed_dim // num_heads
        assert (
            head_dim * num_heads == embed_dim
        ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"

        # Single fused projection, then split into q/k/v thirds.
        q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        #
        # reshape q, k, v for multihead attention and make em batch first:
        # (L, N, E) -> (N*num_heads, L, head_dim)
        #
        q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
        k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
        v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

        if pairwise_features is not None:
            # Expand pairwise features, which should have dimension the size of
            # the attn head dim
            # B x L x L x H => L x L x (B*Nh) x (H/nh)
            pairwise_features = pairwise_features.permute(1, 2, 0, 3).contiguous()
            pairwise_features = pairwise_features.view(
                tgt_len, tgt_len, bsz * num_heads, head_dim
            )

            # L x L x (B*Nh) x (H/nh) => (B*Nh) x L x L x (H / Nh)
            pairwise_features = pairwise_features.permute(2, 0, 1, 3)

            # Uncomment if we project into hidden dim only
            # pairwise_features = pairwise_features.repeat_interleave(self.num_heads, 0)

        # update source sequence length after adjustments
        src_len = k.size(1)

        # merge key padding and attention masks: broadcast the (N, S) padding
        # mask across heads to (N*num_heads, 1, S)
        attn_mask = None
        if key_padding_mask is not None:
            assert key_padding_mask.shape == (
                bsz,
                src_len,
            ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
            key_padding_mask = (
                key_padding_mask.view(bsz, 1, 1, src_len)
                .expand(-1, num_heads, -1, -1)
                .reshape(bsz * num_heads, 1, src_len)
            )
            attn_mask = key_padding_mask
            assert attn_mask.dtype == torch.bool

        # adjust dropout probability: no dropout at eval time
        if not training:
            dropout_p = 0.0

        #
        # calculate attention and out projection
        #
        if self.additive_attn:
            attn_output, attn_output_weights = self._additive_attn(
                q, k, v, attn_mask, dropout_p, pairwise_features=pairwise_features
            )
        else:
            attn_output, attn_output_weights = self._scaled_dot_product_attention(
                q, k, v, attn_mask, dropout_p, pairwise_features=pairwise_features
            )
        # Fold heads back together and apply the output projection.
        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        )
        attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights
+ def _additive_attn(
483
+ self,
484
+ q: Tensor,
485
+ k: Tensor,
486
+ v: Tensor,
487
+ attn_mask: Optional[Tensor] = None,
488
+ dropout_p: float = 0.0,
489
+ pairwise_features: Optional[Tensor] = None,
490
+ ) -> Tuple[Tensor, Tensor]:
491
+ """_additive_attn.
492
+
493
+ Args:
494
+ q (Tensor): q
495
+ k (Tensor): k
496
+ v (Tensor): v
497
+ attn_mask (Optional[Tensor]): attn_mask
498
+ dropout_p (float): dropout_p
499
+ pairwise_features (Optional[Tensor]): pairwise_features
500
+
501
+ Returns:
502
+ Tuple[Tensor, Tensor]:
503
+ """
504
+ r"""
505
+ Computes scaled dot product attention on query, key and value tensors, using
506
+ an optional attention mask if passed, and applying dropout if a probability
507
+ greater than 0.0 is specified.
508
+ Returns a tensor pair containing attended values and attention weights.
509
+ Args:
510
+ q, k, v: query, key and value tensors. See Shape section for shape details.
511
+ attn_mask: optional tensor containing mask values to be added to calculated
512
+ attention. May be 2D or 3D; see Shape section for details.
513
+ dropout_p: dropout probability. If greater than 0.0, dropout is applied.
514
+ pairwise_features: Optional tensor for pairwise
515
+ featurizations
516
+ Shape:
517
+ - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
518
+ and E is embedding dimension.
519
+ - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
520
+ and E is embedding dimension.
521
+ - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
522
+ and E is embedding dimension.
523
+ - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
524
+ shape :math:`(Nt, Ns)`.
525
+ - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
526
+ have shape :math:`(B, Nt, Ns)`
527
+ """
528
+ # NOTE: Consider removing position i attending to itself?
529
+
530
+ B, Nt, E = q.shape
531
+ # Need linear layer here :/
532
+ # B x Nt x E => B x Nt x Nt x E
533
+ q_expand = q[:, :, None, :].expand(B, Nt, Nt, E)
534
+ v_expand = v[:, None, :, :].expand(B, Nt, Nt, E)
535
+ # B x Nt x Nt x E => B x Nt x Nt x 2E
536
+ cat_ar = [q_expand, v_expand]
537
+ if pairwise_features is not None:
538
+ cat_ar.append(pairwise_features)
539
+
540
+ output = torch.cat(cat_ar, -1)
541
+ E_long = E * len(cat_ar)
542
+
543
+ output = output.view(-1, self.num_heads, Nt, Nt, E_long)
544
+
545
+ # B x Nt x Nt x len(cat_ar)*E => B x Nt x Nt x E
546
+ ## This was a fixed attn weight for each head, now separating
547
+ # output = self.attn_weight_1(output)
548
+ output = torch.einsum("bnlwe,neh->bnlwh", output, self.attn_weight_1_weight)
549
+
550
+ output = output + self.attn_weight_1_bias[None, :, None, None, :]
551
+
552
+ output = F.leaky_relu(output)
553
+
554
+ # B x Nt x Nt x len(cat_ar)*E => B x Nt x Nt
555
+ # attn = self.attn_weight_2(output).squeeze()
556
+ attn = torch.einsum("bnlwh,nhi->bnlwi", output, self.attn_weight_2_weight)
557
+ attn = attn + self.attn_weight_2_bias[None, :, None, None, :]
558
+ attn = attn.contiguous().view(-1, Nt, Nt)
559
+ if attn_mask is not None:
560
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
561
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
562
+ attn += attn_mask
563
+ attn = F.softmax(attn, dim=-1)
564
+ output = torch.bmm(attn, v)
565
+ return output, attn
566
+
567
+ def _scaled_dot_product_attention(
568
+ self,
569
+ q: Tensor,
570
+ k: Tensor,
571
+ v: Tensor,
572
+ attn_mask: Optional[Tensor] = None,
573
+ dropout_p: float = 0.0,
574
+ pairwise_features: Optional[Tensor] = None,
575
+ ) -> Tuple[Tensor, Tensor]:
576
+ r"""
577
+ Computes scaled dot product attention on query, key and value tensors, using
578
+ an optional attention mask if passed, and applying dropout if a probability
579
+ greater than 0.0 is specified.
580
+ Returns a tensor pair containing attended values and attention weights.
581
+ Args:
582
+ q, k, v: query, key and value tensors. See Shape section for shape details.
583
+ attn_mask: optional tensor containing mask values to be added to calculated
584
+ attention. May be 2D or 3D; see Shape section for details.
585
+ dropout_p: dropout probability. If greater than 0.0, dropout is applied.
586
+ pairwise_features: Optional tensor for pairwise
587
+ featurizations
588
+ Shape:
589
+ - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
590
+ and E is embedding dimension.
591
+ - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
592
+ and E is embedding dimension.
593
+ - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
594
+ and E is embedding dimension.
595
+ - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
596
+ shape :math:`(Nt, Ns)`.
597
+ - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
598
+ have shape :math:`(B, Nt, Ns)`
599
+ """
600
+ B, Nt, E = q.shape
601
+ q = q / math.sqrt(E)
602
+
603
+ if self.pairwise_featurization:
604
+ ## Inspired by Graph2Smiles and TransformerXL
605
+ # We use pairwise embedding / corrections
606
+ if pairwise_features is None:
607
+ raise ValueError()
608
+
609
+ # B*Nh x Nt x E => B x Nh x Nt x E
610
+ q = q.view(-1, self.num_heads, Nt, E)
611
+ q_1 = q + self.bias_u[None, :, None, :]
612
+ q_2 = q + self.bias_v[None, :, None, :]
613
+
614
+ # B x Nh x Nt x E => B*Nh x Nt x E
615
+ q_1 = q_1.view(-1, Nt, E)
616
+ q_2 = q_2.view(-1, Nt, E)
617
+
618
+ # B x Nh x Nt x E => B x Nh x Nt x Nt
619
+ a_c = torch.einsum("ble,bwe->blw", q_1, k)
620
+
621
+ # pairwise: B*Nh x Nt x Nt x E
622
+ # q_2: B*Nh x Nt x E
623
+ b_d = torch.einsum("ble,blwe->blw", q_2, pairwise_features)
624
+
625
+ attn = a_c + b_d
626
+ else:
627
+ # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
628
+ attn = torch.bmm(q, k.transpose(-2, -1))
629
+
630
+ if attn_mask is not None:
631
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
632
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
633
+ attn += attn_mask
634
+
635
+ attn = F.softmax(attn, dim=-1)
636
+ if dropout_p > 0.0:
637
+ attn = F.dropout(attn, p=dropout_p)
638
+ # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
639
+ output = torch.bmm(attn, v)
640
+ return output, attn