Spaces:
Sleeping
Sleeping
Upload 10 files
Browse files- nn_utils/__init__.py +2 -0
- nn_utils/__pycache__/__init__.cpython-38.pyc +0 -0
- nn_utils/__pycache__/form_embedder.cpython-38.pyc +0 -0
- nn_utils/__pycache__/nn_utils.cpython-38.pyc +0 -0
- nn_utils/__pycache__/transformer_layers.cpython-38.pyc +0 -0
- nn_utils/base_hyperopt.py +146 -0
- nn_utils/form_embedder.py +291 -0
- nn_utils/gitkeep.txt +0 -0
- nn_utils/nn_utils.py +96 -0
- nn_utils/transformer_layers.py +640 -0
nn_utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .nn_utils import *
|
| 2 |
+
from .transformer_layers import *
|
nn_utils/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (202 Bytes). View file
|
|
|
nn_utils/__pycache__/form_embedder.cpython-38.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
nn_utils/__pycache__/nn_utils.cpython-38.pyc
ADDED
|
Binary file (2.87 kB). View file
|
|
|
nn_utils/__pycache__/transformer_layers.cpython-38.pyc
ADDED
|
Binary file (20.2 kB). View file
|
|
|
nn_utils/base_hyperopt.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" base_hyperopt.py
|
| 2 |
+
|
| 3 |
+
Abstract away common hyperopt functionality
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
import logging
|
| 7 |
+
import yaml
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from typing import Callable
|
| 11 |
+
|
| 12 |
+
import pytorch_lightning as pl
|
| 13 |
+
|
| 14 |
+
import ray
|
| 15 |
+
from ray import tune
|
| 16 |
+
from ray.air.config import RunConfig
|
| 17 |
+
from ray.tune.search import ConcurrencyLimiter
|
| 18 |
+
from ray.tune.search.optuna import OptunaSearch
|
| 19 |
+
from ray.tune.schedulers.async_hyperband import ASHAScheduler
|
| 20 |
+
|
| 21 |
+
import mist_cf.common as common
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def add_hyperopt_args(parser):
    """Register Ray Tune hyperparameter-search CLI options on *parser*.

    Adds a "Hyperopt Args" argument group and overrides the parser's
    default ``save_dir`` with a date-stamped results directory.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible) instance.
    """
    group = parser.add_argument_group("Hyperopt Args")
    tune_options = [
        ("--cpus-per-trial", 1, int),
        ("--gpus-per-trial", 1, float),
        ("--num-h-samples", 50, int),
        ("--grace-period", 60 * 15, int),
        ("--max-concurrent", 10, int),
    ]
    for flag, default, arg_type in tune_options:
        group.add_argument(flag, default=default, type=arg_type)
    group.add_argument("--tune-checkpoint", default=None)

    # Replace the generic save location with a dated hyperopt folder
    date_stamp = datetime.now().strftime("%Y_%m_%d")
    parser.set_defaults(save_dir=f"results/{date_stamp}_hyperopt/")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def run_hyperopt(
    kwargs: dict,
    score_function: Callable,
    param_space_function: Callable,
    initial_points: list,
    gen_shared_data: Callable = lambda params: {},
):
    """run_hyperopt.

    Run a Ray Tune + Optuna hyperparameter search with ASHA early stopping.

    Args:
        kwargs: All dictionary args for hyperopt and train. NOTE: mutated in
            place (``gpu``, and in debug mode ``num_h_samples``/``max_epochs``).
        score_function: Trainable function that sets up model train
        param_space_function: Function to suggest new params
        initial_points: List of initial params to try
        gen_shared_data: Optional callable producing extra keyword data that
            is shared across all trials via ``tune.with_parameters``.
    """
    # init ray with new session
    ray.init(address="local")

    # Fix base_args based upon tune args
    # gpu flag is derived from the requested per-trial GPU fraction
    kwargs["gpu"] = kwargs.get("gpus_per_trial", 0) > 0
    # max_t = args.max_epochs

    # Debug mode shrinks the search so it finishes quickly
    if kwargs["debug"]:
        kwargs["num_h_samples"] = 10
        kwargs["max_epochs"] = 5

    save_dir = kwargs["save_dir"]
    common.setup_logger(
        save_dir, log_name="hyperopt.log", debug=kwargs.get("debug", False)
    )
    pl.utilities.seed.seed_everything(kwargs.get("seed"))

    # Data shared across trials (placed once in the Ray object store)
    shared_args = gen_shared_data(kwargs)

    # Define score function
    # orig_dir lets trials resolve paths relative to the launch directory,
    # since Ray changes each trial's working directory.
    trainable = tune.with_parameters(
        score_function, base_args=kwargs, orig_dir=Path().resolve(), **shared_args
    )

    # Dump args
    yaml_args = yaml.dump(kwargs)
    logging.info(f"\n{yaml_args}")
    with open(Path(save_dir) / "args.yaml", "w") as fp:
        fp.write(yaml_args)

    metric = "val_loss"

    # Include cpus and gpus per trial
    # Second resource bundle reserves CPUs for dataloader workers.
    trainable = tune.with_resources(
        trainable,
        resources=tune.PlacementGroupFactory(
            [
                {
                    "CPU": kwargs.get("cpus_per_trial"),
                    "GPU": kwargs.get("gpus_per_trial"),
                },
                {
                    "CPU": kwargs.get("num_workers"),
                },
            ],
            strategy="PACK",
        ),
    )

    # Optuna suggests configs, seeded with user-provided initial points
    search_algo = OptunaSearch(
        metric=metric,
        mode="min",
        points_to_evaluate=initial_points,
        space=param_space_function,
    )
    search_algo = ConcurrencyLimiter(
        search_algo, max_concurrent=kwargs["max_concurrent"]
    )

    # ASHA stops poor trials early, budgeted by wall-clock time (seconds)
    tuner = tune.Tuner(
        trainable,
        tune_config=tune.TuneConfig(
            mode="min",
            metric=metric,
            search_alg=search_algo,
            scheduler=ASHAScheduler(
                max_t=24 * 60 * 60,  # max_t,
                time_attr="time_total_s",
                grace_period=kwargs.get("grace_period"),
                reduction_factor=2,
            ),
            num_samples=kwargs.get("num_h_samples"),
        ),
        run_config=RunConfig(name=None, local_dir=kwargs["save_dir"]),
    )

    # Optionally resume a previous tuning run from its checkpoint directory
    if kwargs.get("tune_checkpoint") is not None:
        ckpt = str(Path(kwargs["tune_checkpoint"]).resolve())
        tuner = tuner.restore(path=ckpt, restart_errored=True)

    results = tuner.fit()
    best_trial = results.get_best_result()
    output = {"score": best_trial.metrics[metric], "config": best_trial.config}
    out_str = yaml.dump(output, indent=2)
    logging.info(out_str)
    with open(Path(save_dir) / "best_trial.yaml", "w") as f:
        f.write(out_str)

    # Output full res table
    results.get_dataframe().to_csv(
        Path(save_dir) / "full_res_tbl.tsv", sep="\t", index=None
    )
|
nn_utils/form_embedder.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
import mist_cf.common as common
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class IntFeaturizer(nn.Module):
    """Base class mapping integer counts to dense vector representations.

    Subclasses must define ``self.int_to_feat_matrix``, a lookup table whose
    row i is the vector representation of the integer i (so the embedding of
    5 is ``self.int_to_feat_matrix[5]``).

    This base class also owns ``NUM_EXTRA_EMBEDDINGS`` learned embeddings for
    special non-count tokens (e.g. the pad / "to be confirmed" token); those
    are addressed with indices starting at ``MAX_COUNT_INT``.
    """

    # Valid count inputs are 0 .. MAX_COUNT_INT - 1.
    MAX_COUNT_INT = 255
    # One learned extra embedding: the "to be confirmed" (pad) token.
    NUM_EXTRA_EMBEDDINGS = 1

    def __init__(self, embedding_dim):
        super().__init__()
        extra = torch.zeros(self.NUM_EXTRA_EMBEDDINGS, embedding_dim)
        self._extra_embeddings = nn.Parameter(extra, requires_grad=True)
        nn.init.normal_(self._extra_embeddings, 0.0, 1.0)
        self.embedding_dim = embedding_dim

    def forward(self, tensor):
        """Embed the integer `tensor`; embeddings of the final input dim are
        flattened together, so output shape is (*tensor.shape[:-1], -1)."""
        in_shape = tensor.shape
        embedded = torch.empty(
            (*in_shape, self.embedding_dim), device=tensor.device
        )
        # Entries >= MAX_COUNT_INT address the learned special embeddings.
        is_special = tensor >= self.MAX_COUNT_INT

        idx = tensor.long()
        embedded[~is_special] = self.int_to_feat_matrix[idx[~is_special]]
        embedded[is_special] = self._extra_embeddings[
            idx[is_special] - self.MAX_COUNT_INT
        ]

        # Collapse the last input dim with the embedding dim.
        return embedded.reshape(*in_shape[:-1], -1)

    @property
    def num_dim(self):
        """Width of one embedding row."""
        return self.int_to_feat_matrix.shape[1]

    @property
    def full_dim(self):
        """Total flattened width when embedding a full formula vector."""
        return self.num_dim * common.NORM_VEC.shape[0]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Binarizer(IntFeaturizer):
    """Fixed featurizer encoding each count via its binary representation.

    Row i of the lookup table is ``common.num_to_binary(i)``; the table is
    frozen (never trained).
    """

    def __init__(self):
        super().__init__(embedding_dim=len(common.num_to_binary(0)))
        bit_rows = [common.num_to_binary(i) for i in range(self.MAX_COUNT_INT)]
        lookup = torch.from_numpy(np.vstack(bit_rows))
        self.int_to_feat_matrix = nn.Parameter(lookup.float())
        self.int_to_feat_matrix.requires_grad = False
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class FourierFeaturizer(IntFeaturizer):
    """Fixed sine/cosine featurization of integer counts.

    Inspired by:
    Tancik, M. et al. (2020) 'Fourier Features Let Networks Learn High
    Frequency Functions in Low Dimensional Domains', arXiv:2006.10739.

    Frequencies sit at powers of 1/2 (rather than random Gaussian samples),
    so the encoding matches the Binarizer quite closely but is smoother.
    """

    def __init__(self):
        # Enough octaves so the whole count range fits on the half circle.
        n_octaves = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2

        octaves = 0.5 ** torch.arange(n_octaves, dtype=torch.float32)
        angular_freqs = 2 * np.pi * octaves

        # One cosine and one sine channel per frequency.
        super().__init__(embedding_dim=2 * angular_freqs.shape[0])

        # Precompute features for every representable count up front.
        phase = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * angular_freqs[None, :]
        )
        table = torch.cat([torch.cos(phase), torch.sin(phase)], dim=1)

        # shape: MAX_COUNT_INT x 2 * n_octaves; frozen lookup table.
        self.int_to_feat_matrix = nn.Parameter(table.float())
        self.int_to_feat_matrix.requires_grad = False
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class FourierFeaturizerSines(IntFeaturizer):
    """Fixed Fourier featurization of integer counts using sines only.

    Inspired by:
    Tancik, M. et al. (2020) 'Fourier Features Let Networks Learn High
    Frequency Functions in Low Dimensional Domains', arXiv:2006.10739.

    Frequencies sit at powers of 1/2 (rather than random Gaussian samples),
    so the encoding matches the Binarizer quite closely but is smoother.
    """

    def __init__(self):
        # Enough octaves so the whole count range fits on the half circle.
        n_octaves = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2

        # Drop the first two (highest) frequencies; keep the rest.
        kept_octaves = (0.5 ** torch.arange(n_octaves, dtype=torch.float32))[2:]
        angular_freqs = 2 * np.pi * kept_octaves

        super().__init__(embedding_dim=angular_freqs.shape[0])

        # Precompute features for every representable count up front.
        phase = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * angular_freqs[None, :]
        )
        # shape: MAX_COUNT_INT x n_freqs; frozen lookup table.
        self.int_to_feat_matrix = nn.Parameter(torch.sin(phase).float())
        self.int_to_feat_matrix.requires_grad = False
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class FourierFeaturizerAbsoluteSines(IntFeaturizer):
    """Fixed Fourier featurization using absolute values of sines only.

    Inspired by:
    Tancik, M. et al. (2020) 'Fourier Features Let Networks Learn High
    Frequency Functions in Low Dimensional Domains', arXiv:2006.10739.

    Frequencies sit at powers of 1/2 (rather than random Gaussian samples),
    so the encoding matches the Binarizer quite closely but is smoother.
    """

    def __init__(self):
        n_octaves = int(np.ceil(np.log2(self.MAX_COUNT_INT))) + 2

        # Drop the first two (highest) frequencies; keep the rest.
        kept_octaves = (0.5 ** torch.arange(n_octaves, dtype=torch.float32))[2:]
        angular_freqs = 2 * np.pi * kept_octaves

        super().__init__(embedding_dim=angular_freqs.shape[0])

        # Precompute features for every representable count up front.
        phase = (
            torch.arange(self.MAX_COUNT_INT, dtype=torch.float32)[:, None]
            * angular_freqs[None, :]
        )
        # shape: MAX_COUNT_INT x n_freqs; frozen lookup table.
        self.int_to_feat_matrix = nn.Parameter(
            torch.abs(torch.sin(phase)).float()
        )
        self.int_to_feat_matrix.requires_grad = False
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class RBFFeaturizer(IntFeaturizer):
    """Fixed radial-basis-function featurization of integer counts.

    Places ``num_funcs`` Gaussian bumps evenly across [0, MAX_COUNT_INT - 1].
    Each bump has width (MAX_COUNT_INT - 1) / num_funcs, so a bump decays to
    about 0.6 of its peak by the time the next center is reached.
    """

    def __init__(self, num_funcs=32):
        """
        :param num_funcs: number of radial basis functions to use; their
            width is chosen automatically -- see class docstring.
        """
        super().__init__(embedding_dim=num_funcs)
        bandwidth = (self.MAX_COUNT_INT - 1) / num_funcs
        centers = torch.linspace(0, self.MAX_COUNT_INT - 1, num_funcs)

        counts = torch.arange(self.MAX_COUNT_INT)[:, None]
        exponents = -0.5 * ((counts - centers[None, :]) / bandwidth) ** 2
        # shape: MAX_COUNT_INT x num_funcs; frozen lookup table.
        self.int_to_feat_matrix = nn.Parameter(torch.exp(exponents).float())
        self.int_to_feat_matrix.requires_grad = False
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class OneHotFeaturizer(IntFeaturizer):
    """Fixed one-hot featurization of integer counts.

    Represents:
    - 0 as 1000000000...
    - 1 as 0100000000...
    - 2 as 0010000000...
    and so on.
    """

    def __init__(self):
        super().__init__(embedding_dim=self.MAX_COUNT_INT)
        identity = torch.eye(self.MAX_COUNT_INT)
        # Frozen lookup table: row i is the one-hot vector for integer i.
        self.int_to_feat_matrix = nn.Parameter(identity.float())
        self.int_to_feat_matrix.requires_grad = False
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class LearnedFeaturizer(IntFeaturizer):
    """Trainable embedding table for integer counts.

    Essentially ``nn.Embedding``, but routed through the superclass forward,
    which stacks embeddings along the final dimension.
    """

    def __init__(self, feature_dim=32):
        super().__init__(embedding_dim=feature_dim)
        table = torch.zeros(self.MAX_COUNT_INT, feature_dim)
        # Trainable lookup table, initialized from a standard normal.
        self.int_to_feat_matrix = nn.Parameter(table, requires_grad=True)
        nn.init.normal_(self.int_to_feat_matrix, 0.0, 1.0)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class FloatFeaturizer(IntFeaturizer):
    """Pass counts through as floats, scaled element-wise by ``common.NORM_VEC``."""

    def __init__(self):
        # embedding_dim=1 is a placeholder: forward returns the normed
        # values themselves rather than table lookups.
        super().__init__(embedding_dim=1)
        norm = torch.from_numpy(common.NORM_VEC).float()
        self.norm_vec = nn.Parameter(norm)
        self.norm_vec.requires_grad = False

    def forward(self, tensor):
        """Divide `tensor` by the norm vector, broadcast over the last dim."""
        lead_dims = len(tensor.shape) - 1
        scale_shape = [1] * lead_dims + [-1]
        return tensor / self.norm_vec.reshape(*scale_shape)

    @property
    def num_dim(self):
        # Each count maps to a single scalar.
        return 1
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def get_embedder(embedder):
    """Instantiate a count featurizer by name.

    Args:
        embedder: One of "binary", "fourier", "rbf", "one-hot", "learnt",
            "float", "fourier-sines", or "abs-sines".

    Returns:
        The corresponding ``IntFeaturizer`` subclass instance.

    Raises:
        NotImplementedError: if `embedder` is not a recognized name.
    """
    registry = {
        "binary": Binarizer,
        "fourier": FourierFeaturizer,
        "rbf": RBFFeaturizer,
        "one-hot": OneHotFeaturizer,
        "learnt": LearnedFeaturizer,
        "float": FloatFeaturizer,
        "fourier-sines": FourierFeaturizerSines,
        "abs-sines": FourierFeaturizerAbsoluteSines,
    }
    if embedder not in registry:
        raise NotImplementedError
    return registry[embedder]()
|
nn_utils/gitkeep.txt
ADDED
|
File without changes
|
nn_utils/nn_utils.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" nn_utils.py
|
| 2 |
+
"""
|
| 3 |
+
import math
|
| 4 |
+
import copy
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def build_lr_scheduler(
    optimizer, lr_decay_rate: float, decay_steps: int = 5000, warmup: int = 100
):
    """Create a LambdaLR schedule: exponential warmup, then stepped decay.

    For steps below `warmup` the multiplier is ``1 - exp(-step / warmup)``
    (ramping from 0 toward ~0.63); afterwards it decays by a factor of
    `lr_decay_rate` every `decay_steps` steps.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        lr_decay_rate (float): multiplicative decay factor.
        decay_steps (int): steps between decay events after warmup.
        warmup (int): number of warmup steps.
    """

    def _multiplier(step):
        if step < warmup:
            # Exponential ramp-up from 0 toward ~0.63 of the base LR.
            return 1 - math.exp(-step / warmup)
        # Stepped exponential decay, counting from the end of warmup.
        return lr_decay_rate ** ((step - warmup) // decay_steps)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_multiplier)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class MLPBlocks(nn.Module):
    """Multi-layer perceptron with residual connections between hidden layers.

    The input is projected to `hidden_size`, then passed through
    ``num_layers - 1`` further hidden layers, each adding a residual (skip)
    connection from the previous layer's output.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        dropout: float,
        num_layers: int,
    ):
        """
        Args:
            input_size (int): dimension of the input features.
            hidden_size (int): dimension of every hidden layer.
            dropout (float): dropout probability applied after each linear layer.
            num_layers (int): total linear layers (input projection plus
                ``num_layers - 1`` residual hidden layers).
        """
        super().__init__()
        self.activation = nn.ReLU()
        self.dropout_layer = nn.Dropout(p=dropout)
        self.input_layer = nn.Linear(input_size, hidden_size)
        middle_layer = nn.Linear(hidden_size, hidden_size)
        # NOTE: get_clones deep-copies, so all hidden layers share the same
        # initial weights but are trained independently.
        self.layers = get_clones(middle_layer, num_layers - 1)

    def forward(self, x):
        # Input projection block (fix: removed dead `output = x` assignment
        # that was immediately overwritten).
        output = self.input_layer(x)
        output = self.dropout_layer(output)
        output = self.activation(output)
        old_output = output
        for layer in self.layers:
            output = layer(output)
            output = self.dropout_layer(output)
            # Residual connection from the previous block's output.
            output = self.activation(output) + old_output
            old_output = output
        return output
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_clones(module, N):
    """Return an ``nn.ModuleList`` of `N` independent deep copies of `module`."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def pad_packed_tensor(input, lengths, value):
    """Unpack a concatenated tensor into a padded (batch, max_len, ...) tensor.

    Args:
        input: tensor of shape (sum(lengths), *rest) holding all sequences
            concatenated along dim 0.
        lengths: per-sequence lengths (list or 1-D tensor).
        value: fill value for the padded positions.

    Returns:
        Tensor of shape (batch, max_len, *rest).
    """
    trailing_shape = input.shape[1:]
    device = input.device
    if isinstance(lengths, torch.Tensor):
        lengths = lengths.to(device)
    else:
        lengths = torch.tensor(lengths, dtype=torch.int64, device=device)
    max_len = int(lengths.max())

    n_seqs = len(lengths)
    padded = input.new(n_seqs * max_len, *trailing_shape)
    padded.fill_(value)

    # Destination index of every packed row in the flat padded buffer:
    # start with 0..N-1, then shift each row after the first sequence by the
    # cumulative padding inserted for all earlier sequences.
    dest = torch.arange(len(input), dtype=torch.int64, device=device)
    cum_padding = torch.cumsum(max_len - lengths, 0)
    # Shift for sequence k (k >= 1), repeated once per element of sequence k.
    shifts = cum_padding[:-1].repeat_interleave(lengths[1:])
    dest[lengths[0]:] += shifts

    # Scatter the packed rows into their padded slots.
    padded[dest] = input
    return padded.view(n_seqs, max_len, *trailing_shape)
|
nn_utils/transformer_layers.py
ADDED
|
@@ -0,0 +1,640 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""transformer_layer.py
|
| 2 |
+
|
| 3 |
+
Hold pairwise attention enabled transformers
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
import math
|
| 7 |
+
from typing import Optional, Union, Callable, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from torch.nn import Module, LayerNorm, Linear, Dropout, Parameter
|
| 13 |
+
from torch.nn.init import xavier_uniform_, constant_
|
| 14 |
+
|
| 15 |
+
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TransformerEncoderLayer(Module):
|
| 19 |
+
r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
|
| 20 |
+
This standard encoder layer is based on the paper "Attention Is All You Need".
|
| 21 |
+
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
| 22 |
+
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
| 23 |
+
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
| 24 |
+
in a different way during application.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
d_model: the number of expected features in the input (required).
|
| 28 |
+
nhead: the number of heads in the multiheadattention models (required).
|
| 29 |
+
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
| 30 |
+
dropout: the dropout value (default=0.1).
|
| 31 |
+
activation: the activation function of the intermediate layer, can be a string
|
| 32 |
+
("relu" or "gelu") or a unary callable. Default: relu
|
| 33 |
+
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
|
| 34 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
| 35 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
| 36 |
+
norm_first: if ``True``, layer norm is done prior to attention and feedforward
|
| 37 |
+
operations, respectively. Otherwise it's done after. Default: ``False`` (after).
|
| 38 |
+
additive_attn: if ``True``, use additive attn instead of scaled dot
|
| 39 |
+
product attention`
|
| 40 |
+
pairwise_featurization: If ``True``
|
| 41 |
+
Examples::
|
| 42 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
|
| 43 |
+
>>> src = torch.rand(10, 32, 512)
|
| 44 |
+
>>> out = encoder_layer(src)
|
| 45 |
+
|
| 46 |
+
Alternatively, when ``batch_first`` is ``True``:
|
| 47 |
+
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
|
| 48 |
+
>>> src = torch.rand(32, 10, 512)
|
| 49 |
+
>>> out = encoder_layer(src)
|
| 50 |
+
"""
|
| 51 |
+
    # TorchScript constants, mirroring torch.nn.TransformerEncoderLayer.
    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        additive_attn: bool = False,
        pairwise_featurization: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        # device/dtype are forwarded to every submodule so the whole layer
        # is created with a consistent placement and precision.
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.pairwise_featurization = pairwise_featurization
        # Custom attention module (defined below in this file), extending the
        # stock torch.nn.MultiheadAttention with optional additive attention
        # and pairwise featurization.
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            additive_attn=additive_attn,
            pairwise_featurization=self.pairwise_featurization,
            **factory_kwargs,
        )
        # Position-wise feedforward model: d_model -> dim_feedforward -> d_model.
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        # norm1 wraps the attention sub-block, norm2 the feedforward
        # sub-block; norm_first selects pre-norm vs post-norm ordering.
        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # NOTE(review): the docstring says activation may be a string
        # ("relu"/"gelu"), but no string-to-callable conversion is visible
        # here — callers appear expected to pass a callable. Verify upstream.
        self.activation = activation
|
| 92 |
+
|
| 93 |
+
def __setstate__(self, state):
|
| 94 |
+
if "activation" not in state:
|
| 95 |
+
state["activation"] = F.relu
|
| 96 |
+
super(TransformerEncoderLayer, self).__setstate__(state)
|
| 97 |
+
|
| 98 |
+
def forward(
|
| 99 |
+
self,
|
| 100 |
+
src: Tensor,
|
| 101 |
+
pairwise_features: Optional[Tensor] = None,
|
| 102 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 103 |
+
) -> Tensor:
|
| 104 |
+
r"""Pass the input through the encoder layer.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
src: the sequence to the encoder layer (required).
|
| 108 |
+
pairwise_features: If set, use this to param pariwise features
|
| 109 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
| 110 |
+
|
| 111 |
+
Shape:
|
| 112 |
+
see the docs in Transformer class.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
|
| 116 |
+
|
| 117 |
+
x = src
|
| 118 |
+
if self.norm_first:
|
| 119 |
+
x = x + self._sa_block(
|
| 120 |
+
self.norm1(x), pairwise_features, src_key_padding_mask
|
| 121 |
+
)
|
| 122 |
+
x = x + self._ff_block(self.norm2(x))
|
| 123 |
+
else:
|
| 124 |
+
x = self.norm1(
|
| 125 |
+
x + self._sa_block(x, pairwise_features, src_key_padding_mask)
|
| 126 |
+
)
|
| 127 |
+
x = self.norm2(x + self._ff_block(x))
|
| 128 |
+
|
| 129 |
+
return x, pairwise_features
|
| 130 |
+
|
| 131 |
+
# self-attention block
|
| 132 |
+
def _sa_block(
|
| 133 |
+
self,
|
| 134 |
+
x: Tensor,
|
| 135 |
+
pairwise_features: Optional[Tensor],
|
| 136 |
+
key_padding_mask: Optional[Tensor],
|
| 137 |
+
) -> Tensor:
|
| 138 |
+
|
| 139 |
+
## Apply joint featurizer
|
| 140 |
+
x = self.self_attn(
|
| 141 |
+
x,
|
| 142 |
+
x,
|
| 143 |
+
x,
|
| 144 |
+
key_padding_mask=key_padding_mask,
|
| 145 |
+
pairwise_features=pairwise_features,
|
| 146 |
+
)[0]
|
| 147 |
+
return self.dropout1(x)
|
| 148 |
+
|
| 149 |
+
# feed forward block
|
| 150 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
| 151 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
| 152 |
+
return self.dropout2(x)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class MultiheadAttention(Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces as described in the paper:
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Multi-Head Attention is defined as:

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

    where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

    Args:
        embed_dim: Total dimension of the model.
        num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
            across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
        additive_attn: If true, use additive attention instead of scaled dot
            product attention
        dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        pairwise_featurization: If ``True``, use pairwise featurization on the
            inputs

    Examples::

        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        additive_attn=False,
        pairwise_featurization: bool = False,
        dropout=0.0,
        batch_first=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()

        # Query/key/value all share the embedding dimension here (no
        # separate kdim/vdim support, unlike torch.nn.MultiheadAttention).
        self.embed_dim = embed_dim
        self.kdim = embed_dim
        self.vdim = embed_dim
        self._qkv_same_embed_dim = True
        self.additive_attn = additive_attn
        self.pairwise_featurization = pairwise_featurization

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        # Per-head feature dimension; embed_dim must split evenly.
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        if self.additive_attn:
            # Additive attention scores each (query_i, value_j[, pairwise_ij])
            # concatenation through a small per-head MLP, so the first layer
            # consumes 2 or 3 head_dim-sized chunks.
            head_1_input = (
                self.head_dim * 3 if self.pairwise_featurization else self.head_dim * 2
            )
            # Layer 1: (num_heads, head_1_input, head_dim) weight per head.
            self.attn_weight_1_weight = Parameter(
                torch.empty(
                    (self.num_heads, head_1_input, self.head_dim), **factory_kwargs
                ),
            )
            self.attn_weight_1_bias = Parameter(
                torch.empty((self.num_heads, self.head_dim), **factory_kwargs),
            )

            # Layer 2 collapses each head_dim vector to a scalar logit.
            self.attn_weight_2_weight = Parameter(
                torch.empty((self.num_heads, self.head_dim, 1), **factory_kwargs),
            )
            self.attn_weight_2_bias = Parameter(
                torch.empty((self.num_heads, 1), **factory_kwargs),
            )
        else:
            if self.pairwise_featurization:
                # Content bias u and positional bias v, per head, in the
                # style of Transformer-XL relative attention.
                self.bias_u = Parameter(
                    torch.empty((self.num_heads, self.head_dim), **factory_kwargs),
                )
                self.bias_v = Parameter(
                    torch.empty((self.num_heads, self.head_dim), **factory_kwargs),
                )

        # Packed Q/K/V input projection (3 * embed_dim rows) and the output
        # projection back to embed_dim.
        self.in_proj_weight = Parameter(
            torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
        )
        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        self.out_proj = NonDynamicallyQuantizableLinear(
            embed_dim, embed_dim, bias=True, **factory_kwargs
        )

        self._reset_parameters()
|
| 254 |
+
|
| 255 |
+
def _reset_parameters(self):
|
| 256 |
+
"""_reset_parameters."""
|
| 257 |
+
xavier_uniform_(self.in_proj_weight)
|
| 258 |
+
constant_(self.in_proj_bias, 0.0)
|
| 259 |
+
constant_(self.out_proj.bias, 0.0)
|
| 260 |
+
if self.additive_attn:
|
| 261 |
+
xavier_uniform_(self.attn_weight_1_weight)
|
| 262 |
+
xavier_uniform_(self.attn_weight_2_weight)
|
| 263 |
+
constant_(self.attn_weight_1_bias, 0.0)
|
| 264 |
+
constant_(self.attn_weight_2_bias, 0.0)
|
| 265 |
+
else:
|
| 266 |
+
if self.pairwise_featurization:
|
| 267 |
+
constant_(self.bias_u, 0.0)
|
| 268 |
+
constant_(self.bias_v, 0.0)
|
| 269 |
+
|
| 270 |
+
    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        pairwise_features: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
                or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
                :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
                Queries are compared against key-value pairs to produce the output.
            key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
                or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
                :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
            value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
                ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
                sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
            key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
                to ignore for the purpose of attention (i.e. treat as "padding"). A ``True`` value indicates
                that the corresponding ``key`` value will be ignored.
            pairwise_features: If specified, use this in the attention mechanism.
                Handled differently for scaled dot product and additive attn.

        Outputs:
            - **attn_output** - Attention outputs of shape
              :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
              where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
              embedding dimension ``embed_dim``.
            - **attn_output_weights** - Attention weights of shape :math:`(N, num_heads, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """
        is_batched = query.dim() == 3
        # The attention math below runs seq-first; convert from
        # (batch, seq, feature) when the module is configured batch_first.
        if self.batch_first and is_batched:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        # Delegate the actual projection + attention computation.
        attn_output, attn_output_weights = self.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            pairwise_features=pairwise_features,
        )

        if self.batch_first and is_batched:
            # Convert the output back to (batch, seq, feature) for the caller.
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
|
| 341 |
+
|
| 342 |
+
    def multi_head_attention_forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        embed_dim_to_check: int,
        num_heads: int,
        in_proj_weight: Tensor,
        in_proj_bias: Optional[Tensor],
        dropout_p: float,
        out_proj_weight: Tensor,
        out_proj_bias: Optional[Tensor],
        training: bool = True,
        key_padding_mask: Optional[Tensor] = None,
        pairwise_features: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            embed_dim_to_check: total dimension of the model.
            num_heads: parallel attention heads.
            in_proj_weight, in_proj_bias: input projection weight and bias.
            dropout_p: probability of an element to be zeroed.
            out_proj_weight, out_proj_bias: the output projection weight and bias.
            training: apply dropout if is ``True``.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is a binary mask. When the value is True,
                the corresponding value on the attention layer will be filled with -inf.
            pairwise_features: If provided, include this in the MHA
        Shape:
            Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              Positions with the value of ``True`` will be ignored while positions with the value of
              ``False`` will be unchanged.
            Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, num_heads, L, S)` attention weights per head.
        """

        # set up shape vars (inputs arrive seq-first: L x N x E)
        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape
        assert (
            embed_dim == embed_dim_to_check
        ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
        if isinstance(embed_dim, torch.Tensor):
            # embed_dim can be a tensor when JIT tracing
            head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
        else:
            head_dim = embed_dim // num_heads
        assert (
            head_dim * num_heads == embed_dim
        ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"

        # Single packed projection produces Q, K, V in one matmul, then split.
        # NOTE(review): this applies the same in_proj to query, key, and value
        # tensors separately only when they are the same object; here the
        # packed weight is applied to `query` alone and chunked, so this path
        # assumes self-attention (query is key is value) — confirm callers.
        q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        #
        # reshape q, k, v for multihead attention and make em batch first:
        # (L, N, E) -> (N*num_heads, L, head_dim)
        #
        q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
        k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
        v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

        if pairwise_features is not None:
            # Expand pairwise features, which should have dimension the size of
            # the attn head dim
            # B x L x L x H => L x L x (B*Nh) x (H/nh)
            pairwise_features = pairwise_features.permute(1, 2, 0, 3).contiguous()
            pairwise_features = pairwise_features.view(
                tgt_len, tgt_len, bsz * num_heads, head_dim
            )

            # L x L x (B*Nh) x (H/nh) => (B*Nh) x L x L x (H / Nh)
            pairwise_features = pairwise_features.permute(2, 0, 1, 3)

            # Uncomment if we project into hidden dim only
            # pairwise_features = pairwise_features.repeat_interleave(self.num_heads, 0)

        # update source sequence length after adjustments
        src_len = k.size(1)

        # Broadcast the (N, S) key padding mask across heads to
        # (N*num_heads, 1, S) so it can be added per attention row.
        attn_mask = None
        if key_padding_mask is not None:
            assert key_padding_mask.shape == (
                bsz,
                src_len,
            ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
            key_padding_mask = (
                key_padding_mask.view(bsz, 1, 1, src_len)
                .expand(-1, num_heads, -1, -1)
                .reshape(bsz * num_heads, 1, src_len)
            )
            attn_mask = key_padding_mask
            assert attn_mask.dtype == torch.bool

        # adjust dropout probability: no attention dropout at eval time
        if not training:
            dropout_p = 0.0

        #
        # calculate attention and out projection
        #
        if self.additive_attn:
            attn_output, attn_output_weights = self._additive_attn(
                q, k, v, attn_mask, dropout_p, pairwise_features=pairwise_features
            )
        else:
            attn_output, attn_output_weights = self._scaled_dot_product_attention(
                q, k, v, attn_mask, dropout_p, pairwise_features=pairwise_features
            )
        # Merge heads back: (N*num_heads, L, head_dim) -> (L*N, E), project,
        # then restore the seq-first (L, N, E) layout.
        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        )
        attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights
|
| 481 |
+
|
| 482 |
+
def _additive_attn(
|
| 483 |
+
self,
|
| 484 |
+
q: Tensor,
|
| 485 |
+
k: Tensor,
|
| 486 |
+
v: Tensor,
|
| 487 |
+
attn_mask: Optional[Tensor] = None,
|
| 488 |
+
dropout_p: float = 0.0,
|
| 489 |
+
pairwise_features: Optional[Tensor] = None,
|
| 490 |
+
) -> Tuple[Tensor, Tensor]:
|
| 491 |
+
"""_additive_attn.
|
| 492 |
+
|
| 493 |
+
Args:
|
| 494 |
+
q (Tensor): q
|
| 495 |
+
k (Tensor): k
|
| 496 |
+
v (Tensor): v
|
| 497 |
+
attn_mask (Optional[Tensor]): attn_mask
|
| 498 |
+
dropout_p (float): dropout_p
|
| 499 |
+
pairwise_features (Optional[Tensor]): pairwise_features
|
| 500 |
+
|
| 501 |
+
Returns:
|
| 502 |
+
Tuple[Tensor, Tensor]:
|
| 503 |
+
"""
|
| 504 |
+
r"""
|
| 505 |
+
Computes scaled dot product attention on query, key and value tensors, using
|
| 506 |
+
an optional attention mask if passed, and applying dropout if a probability
|
| 507 |
+
greater than 0.0 is specified.
|
| 508 |
+
Returns a tensor pair containing attended values and attention weights.
|
| 509 |
+
Args:
|
| 510 |
+
q, k, v: query, key and value tensors. See Shape section for shape details.
|
| 511 |
+
attn_mask: optional tensor containing mask values to be added to calculated
|
| 512 |
+
attention. May be 2D or 3D; see Shape section for details.
|
| 513 |
+
dropout_p: dropout probability. If greater than 0.0, dropout is applied.
|
| 514 |
+
pairwise_features: Optional tensor for pairwise
|
| 515 |
+
featurizations
|
| 516 |
+
Shape:
|
| 517 |
+
- q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
|
| 518 |
+
and E is embedding dimension.
|
| 519 |
+
- key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
|
| 520 |
+
and E is embedding dimension.
|
| 521 |
+
- value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
|
| 522 |
+
and E is embedding dimension.
|
| 523 |
+
- attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
|
| 524 |
+
shape :math:`(Nt, Ns)`.
|
| 525 |
+
- Output: attention values have shape :math:`(B, Nt, E)`; attention weights
|
| 526 |
+
have shape :math:`(B, Nt, Ns)`
|
| 527 |
+
"""
|
| 528 |
+
# NOTE: Consider removing position i attending to itself?
|
| 529 |
+
|
| 530 |
+
B, Nt, E = q.shape
|
| 531 |
+
# Need linear layer here :/
|
| 532 |
+
# B x Nt x E => B x Nt x Nt x E
|
| 533 |
+
q_expand = q[:, :, None, :].expand(B, Nt, Nt, E)
|
| 534 |
+
v_expand = v[:, None, :, :].expand(B, Nt, Nt, E)
|
| 535 |
+
# B x Nt x Nt x E => B x Nt x Nt x 2E
|
| 536 |
+
cat_ar = [q_expand, v_expand]
|
| 537 |
+
if pairwise_features is not None:
|
| 538 |
+
cat_ar.append(pairwise_features)
|
| 539 |
+
|
| 540 |
+
output = torch.cat(cat_ar, -1)
|
| 541 |
+
E_long = E * len(cat_ar)
|
| 542 |
+
|
| 543 |
+
output = output.view(-1, self.num_heads, Nt, Nt, E_long)
|
| 544 |
+
|
| 545 |
+
# B x Nt x Nt x len(cat_ar)*E => B x Nt x Nt x E
|
| 546 |
+
## This was a fixed attn weight for each head, now separating
|
| 547 |
+
# output = self.attn_weight_1(output)
|
| 548 |
+
output = torch.einsum("bnlwe,neh->bnlwh", output, self.attn_weight_1_weight)
|
| 549 |
+
|
| 550 |
+
output = output + self.attn_weight_1_bias[None, :, None, None, :]
|
| 551 |
+
|
| 552 |
+
output = F.leaky_relu(output)
|
| 553 |
+
|
| 554 |
+
# B x Nt x Nt x len(cat_ar)*E => B x Nt x Nt
|
| 555 |
+
# attn = self.attn_weight_2(output).squeeze()
|
| 556 |
+
attn = torch.einsum("bnlwh,nhi->bnlwi", output, self.attn_weight_2_weight)
|
| 557 |
+
attn = attn + self.attn_weight_2_bias[None, :, None, None, :]
|
| 558 |
+
attn = attn.contiguous().view(-1, Nt, Nt)
|
| 559 |
+
if attn_mask is not None:
|
| 560 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
| 561 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
| 562 |
+
attn += attn_mask
|
| 563 |
+
attn = F.softmax(attn, dim=-1)
|
| 564 |
+
output = torch.bmm(attn, v)
|
| 565 |
+
return output, attn
|
| 566 |
+
|
| 567 |
+
def _scaled_dot_product_attention(
|
| 568 |
+
self,
|
| 569 |
+
q: Tensor,
|
| 570 |
+
k: Tensor,
|
| 571 |
+
v: Tensor,
|
| 572 |
+
attn_mask: Optional[Tensor] = None,
|
| 573 |
+
dropout_p: float = 0.0,
|
| 574 |
+
pairwise_features: Optional[Tensor] = None,
|
| 575 |
+
) -> Tuple[Tensor, Tensor]:
|
| 576 |
+
r"""
|
| 577 |
+
Computes scaled dot product attention on query, key and value tensors, using
|
| 578 |
+
an optional attention mask if passed, and applying dropout if a probability
|
| 579 |
+
greater than 0.0 is specified.
|
| 580 |
+
Returns a tensor pair containing attended values and attention weights.
|
| 581 |
+
Args:
|
| 582 |
+
q, k, v: query, key and value tensors. See Shape section for shape details.
|
| 583 |
+
attn_mask: optional tensor containing mask values to be added to calculated
|
| 584 |
+
attention. May be 2D or 3D; see Shape section for details.
|
| 585 |
+
dropout_p: dropout probability. If greater than 0.0, dropout is applied.
|
| 586 |
+
pairwise_features: Optional tensor for pairwise
|
| 587 |
+
featurizations
|
| 588 |
+
Shape:
|
| 589 |
+
- q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
|
| 590 |
+
and E is embedding dimension.
|
| 591 |
+
- key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
|
| 592 |
+
and E is embedding dimension.
|
| 593 |
+
- value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
|
| 594 |
+
and E is embedding dimension.
|
| 595 |
+
- attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
|
| 596 |
+
shape :math:`(Nt, Ns)`.
|
| 597 |
+
- Output: attention values have shape :math:`(B, Nt, E)`; attention weights
|
| 598 |
+
have shape :math:`(B, Nt, Ns)`
|
| 599 |
+
"""
|
| 600 |
+
B, Nt, E = q.shape
|
| 601 |
+
q = q / math.sqrt(E)
|
| 602 |
+
|
| 603 |
+
if self.pairwise_featurization:
|
| 604 |
+
## Inspired by Graph2Smiles and TransformerXL
|
| 605 |
+
# We use pairwise embedding / corrections
|
| 606 |
+
if pairwise_features is None:
|
| 607 |
+
raise ValueError()
|
| 608 |
+
|
| 609 |
+
# B*Nh x Nt x E => B x Nh x Nt x E
|
| 610 |
+
q = q.view(-1, self.num_heads, Nt, E)
|
| 611 |
+
q_1 = q + self.bias_u[None, :, None, :]
|
| 612 |
+
q_2 = q + self.bias_v[None, :, None, :]
|
| 613 |
+
|
| 614 |
+
# B x Nh x Nt x E => B*Nh x Nt x E
|
| 615 |
+
q_1 = q_1.view(-1, Nt, E)
|
| 616 |
+
q_2 = q_2.view(-1, Nt, E)
|
| 617 |
+
|
| 618 |
+
# B x Nh x Nt x E => B x Nh x Nt x Nt
|
| 619 |
+
a_c = torch.einsum("ble,bwe->blw", q_1, k)
|
| 620 |
+
|
| 621 |
+
# pairwise: B*Nh x Nt x Nt x E
|
| 622 |
+
# q_2: B*Nh x Nt x E
|
| 623 |
+
b_d = torch.einsum("ble,blwe->blw", q_2, pairwise_features)
|
| 624 |
+
|
| 625 |
+
attn = a_c + b_d
|
| 626 |
+
else:
|
| 627 |
+
# (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
|
| 628 |
+
attn = torch.bmm(q, k.transpose(-2, -1))
|
| 629 |
+
|
| 630 |
+
if attn_mask is not None:
|
| 631 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
| 632 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
| 633 |
+
attn += attn_mask
|
| 634 |
+
|
| 635 |
+
attn = F.softmax(attn, dim=-1)
|
| 636 |
+
if dropout_p > 0.0:
|
| 637 |
+
attn = F.dropout(attn, p=dropout_p)
|
| 638 |
+
# (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
|
| 639 |
+
output = torch.bmm(attn, v)
|
| 640 |
+
return output, attn
|