diff --git "a/huggingface_wrapper.py" "b/huggingface_wrapper.py" --- "a/huggingface_wrapper.py" +++ "b/huggingface_wrapper.py" @@ -21,2113 +21,2115 @@ from torch import Tensor from torch.nn import Dropout, LayerNorm, Linear from transformers import PretrainedConfig, PreTrainedModel -UUID_URL_MAP = { - # global source models - "D72M9aEp": "https://zenodo.org/records/14908509/files/METL-G-20M-1D-D72M9aEp.pt?download=1", - "Nr9zCKpR": "https://zenodo.org/records/14908509/files/METL-G-20M-3D-Nr9zCKpR.pt?download=1", - "auKdzzwX": "https://zenodo.org/records/14908509/files/METL-G-50M-1D-auKdzzwX.pt?download=1", - "6PSAzdfv": "https://zenodo.org/records/14908509/files/METL-G-50M-3D-6PSAzdfv.pt?download=1", - # local source models - "8gMPQJy4": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-GFP-8gMPQJy4.pt?download=1", - "Hr4GNHws": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-GFP-Hr4GNHws.pt?download=1", - "8iFoiYw2": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-DLG4_2022-8iFoiYw2.pt?download=1", - "kt5DdWTa": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-DLG4_2022-kt5DdWTa.pt?download=1", - "DMfkjVzT": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-GB1-DMfkjVzT.pt?download=1", - "epegcFiH": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-GB1-epegcFiH.pt?download=1", - "kS3rUS7h": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-GRB2-kS3rUS7h.pt?download=1", - "X7w83g6S": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-GRB2-X7w83g6S.pt?download=1", - "UKebCQGz": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-Pab1-UKebCQGz.pt?download=1", - "2rr8V4th": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-Pab1-2rr8V4th.pt?download=1", - "PREhfC22": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-TEM-1-PREhfC22.pt?download=1", - "9ASvszux": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-TEM-1-9ASvszux.pt?download=1", - "HscFFkAb": 
"https://zenodo.org/records/14908509/files/METL-L-2M-1D-Ube4b-HscFFkAb.pt?download=1", - "H48oiNZN": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-Ube4b-H48oiNZN.pt?download=1", - "CEMSx7ZC": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-PTEN-CEMSx7ZC.pt?download=1", - "PjxR5LW7": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-PTEN-PjxR5LW7.pt?download=1", - # metl bind source models - "K6mw24Rg": "https://zenodo.org/records/14908509/files/METL-BIND-2M-3D-GB1-STANDARD-K6mw24Rg.pt?download=1", - "Bo5wn2SG": "https://zenodo.org/records/14908509/files/METL-BIND-2M-3D-GB1-BINDING-Bo5wn2SG.pt?download=1", - # finetuned models from GFP design experiment - "YoQkzoLD": "https://zenodo.org/records/14908509/files/FT-METL-L-2M-1D-GFP-YoQkzoLD.pt?download=1", - "PEkeRuxb": "https://zenodo.org/records/14908509/files/FT-METL-L-2M-3D-GFP-PEkeRuxb.pt?download=1", -} - -IDENT_UUID_MAP = { - # the keys should be all lowercase - "metl-g-20m-1d": "D72M9aEp", - "metl-g-20m-3d": "Nr9zCKpR", - "metl-g-50m-1d": "auKdzzwX", - "metl-g-50m-3d": "6PSAzdfv", - # GFP local source models - "metl-l-2m-1d-gfp": "8gMPQJy4", - "metl-l-2m-3d-gfp": "Hr4GNHws", - # DLG4 local source models - "metl-l-2m-1d-dlg4_2022": "8iFoiYw2", - "metl-l-2m-3d-dlg4_2022": "kt5DdWTa", - # GB1 local source models - "metl-l-2m-1d-gb1": "DMfkjVzT", - "metl-l-2m-3d-gb1": "epegcFiH", - # GRB2 local source models - "metl-l-2m-1d-grb2": "kS3rUS7h", - "metl-l-2m-3d-grb2": "X7w83g6S", - # Pab1 local source models - "metl-l-2m-1d-pab1": "UKebCQGz", - "metl-l-2m-3d-pab1": "2rr8V4th", - # PTEN local source models - "metl-l-2m-1d-pten": "CEMSx7ZC", - "metl-l-2m-3d-pten": "PjxR5LW7", - # TEM-1 local source models - "metl-l-2m-1d-tem-1": "PREhfC22", - "metl-l-2m-3d-tem-1": "9ASvszux", - # Ube4b local source models - "metl-l-2m-1d-ube4b": "HscFFkAb", - "metl-l-2m-3d-ube4b": "H48oiNZN", - # METL-Bind for GB1 - "metl-bind-2m-3d-gb1-standard": "K6mw24Rg", - "metl-bind-2m-3d-gb1-binding": "Bo5wn2SG", - # GFP 
design models, giving them an ident - "metl-l-2m-1d-gfp-ft-design": "YoQkzoLD", - "metl-l-2m-3d-gfp-ft-design": "PEkeRuxb", -} - +""" implementation of transformer encoder with relative attention + references: + - https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a + - https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer + - https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py + - https://github.com/jiezouguihuafu/ClassicalModelreproduced/blob/main/Transformer/transfor_rpe.py +""" -def download_checkpoint(uuid): - ckpt = torch.hub.load_state_dict_from_url( - UUID_URL_MAP[uuid], map_location="cpu", file_name=f"{uuid}.pt" - ) - state_dict = ckpt["state_dict"] - hyper_parameters = ckpt["hyper_parameters"] - return state_dict, hyper_parameters +class RelativePosition3D(nn.Module): + """Contact map-based relative position embeddings""" + # need to compute a bucket_mtx for each structure + # need to know which bucket_mtx to use when grabbing the embeddings in forward() + # - on init, get a list of all PDB files we will be using + # - use a dictionary to store PDB files --> bucket_mtxs + # - forward() gets a new arg: the pdb file, which indexes into the dictionary to grab the right bucket_mtx + def __init__( + self, + embedding_len: int, + contact_threshold: int, + clipping_threshold: int, + pdb_fns: Optional[Union[str, list, tuple]] = None, + default_pdb_dir: str = "data/pdb_files", + ): -def _get_data_encoding(hparams): - if "encoding" in hparams and hparams["encoding"] == "int_seqs": - encoding = Encoding.INT_SEQS - elif "encoding" in hparams and hparams["encoding"] == "one_hot": - encoding = Encoding.ONE_HOT - elif ( - ("encoding" in hparams and hparams["encoding"] == "auto") - or "encoding" not in hparams - ) and hparams["model_name"] in ["transformer_encoder"]: - encoding = Encoding.INT_SEQS - else: - raise 
ValueError("Detected unsupported encoding in hyperparameters") + # preferably, pdb_fns contains full paths to the PDBs, but if just the PDB filename is given + # then it defaults to the path data/pdb_files/ + super().__init__() + self.embedding_len = embedding_len + self.clipping_threshold = clipping_threshold + self.contact_threshold = contact_threshold + self.default_pdb_dir = default_pdb_dir - return encoding + # dummy buffer for getting correct device for on-the-fly bucket matrix generation + self.register_buffer("dummy_buffer", torch.empty(0), persistent=False) + # for 3D-based positions, the number of embeddings is generally the number of buckets + # for contact map-based distances, that is clipping_threshold + 1 + num_embeddings = clipping_threshold + 1 -def load_model_and_data_encoder(state_dict, hparams): - model = Model[hparams["model_name"]].cls(**hparams) - model.load_state_dict(state_dict) + # this is the embedding lookup table E_r + self.embeddings_table = nn.Embedding(num_embeddings, embedding_len) - data_encoder = DataEncoder(_get_data_encoding(hparams)) + # set up pdb_fns that were passed in on init (can also be set up during runtime in forward()) + # todo: i'm using a hacky workaround to move the bucket_mtxs to the correct device + # i tried to make it more efficient by registering bucket matrices as buffers, but i was + # having problems with DDP syncing the buffers across processes + self.bucket_mtxs = {} + self.bucket_mtxs_device = self.dummy_buffer.device + self._init_pdbs(pdb_fns) - return model, data_encoder + def forward(self, pdb_fn): + # compute matrix R by grabbing the embeddings from the embeddings lookup table + embeddings = self.embeddings_table(self._get_bucket_mtx(pdb_fn)) + return embeddings + # def _get_bucket_mtx(self, pdb_fn): + # """ retrieve a bucket matrix given the pdb_fn. + # if the pdb_fn was provided at init or has already been computed, then the bucket matrix will be + # retrieved from the object buffer. 
if the bucket matrix has not been computed yet, it will be here """ + # pdb_attr = self._pdb_key(pdb_fn) + # if hasattr(self, pdb_attr): + # return getattr(self, pdb_attr) + # else: + # # encountering a new PDB at runtime... process it + # # todo: if there's a new PDB at runtime, it will be initialized separately in each instance + # # of RelativePosition3D, for each layer. It would be more efficient to have a global + # # bucket_mtx registry... perhaps in the RelativeTransformerEncoder class, that can be passed through + # self._init_pdb(pdb_fn) + # return getattr(self, pdb_attr) -def get_from_uuid(uuid): - if uuid in UUID_URL_MAP: - state_dict, hparams = download_checkpoint(uuid) - return load_model_and_data_encoder(state_dict, hparams) - else: - raise ValueError(f"UUID {uuid} not found in UUID_URL_MAP") + def _move_bucket_mtxs(self, device): + for k, v in self.bucket_mtxs.items(): + self.bucket_mtxs[k] = v.to(device) + self.bucket_mtxs_device = device + def _get_bucket_mtx(self, pdb_fn): + """retrieve a bucket matrix given the pdb_fn. + if the pdb_fn was provided at init or has already been computed, then the bucket matrix will be + retrieved from the bucket_mtxs dictionary. else, it will be computed now on-the-fly + """ -def get_from_ident(ident): - ident = ident.lower() - if ident in IDENT_UUID_MAP: - state_dict, hparams = download_checkpoint(IDENT_UUID_MAP[ident]) - return load_model_and_data_encoder(state_dict, hparams) - else: - raise ValueError(f"Identifier {ident} not found in IDENT_UUID_MAP") + # ensure that all the bucket matrices are on the same device as the nn.Embedding + if self.bucket_mtxs_device != self.dummy_buffer.device: + self._move_bucket_mtxs(self.dummy_buffer.device) + pdb_attr = self._pdb_key(pdb_fn) + if pdb_attr in self.bucket_mtxs: + return self.bucket_mtxs[pdb_attr] + else: + # encountering a new PDB at runtime... 
process it + # todo: if there's a new PDB at runtime, it will be initialized separately in each instance + # of RelativePosition3D, for each layer. It would be more efficient to have a global + # bucket_mtx registry... perhaps in the RelativeTransformerEncoder class, that can be passed through + self._init_pdb(pdb_fn) + return self.bucket_mtxs[pdb_attr] -def get_from_checkpoint(ckpt_fn): - ckpt = torch.load(ckpt_fn, map_location="cpu") - state_dict = ckpt["state_dict"] - hyper_parameters = ckpt["hyper_parameters"] - return load_model_and_data_encoder(state_dict, hyper_parameters) + # def _set_bucket_mtx(self, pdb_fn, bucket_mtx): + # """ store a bucket matrix as a buffer """ + # # if PyTorch ever implements a BufferDict, we could use it here efficiently + # # there is also BufferDict from https://botorch.org/api/_modules/botorch/utils/torch.html + # # would just need to modify it to have an option for persistent=False + # bucket_mtx = bucket_mtx.to(self.dummy_buffer.device) + # + # self.register_buffer(self._pdb_key(pdb_fn), bucket_mtx, persistent=False) + def _set_bucket_mtx(self, pdb_fn, bucket_mtx): + """store a bucket matrix in the bucket dict""" -def reset_parameters_helper(m: nn.Module): - """helper function for resetting model parameters, meant to be used with model.apply()""" + # move the bucket_mtx to the same device that the other bucket matrices are on + bucket_mtx = bucket_mtx.to(self.bucket_mtxs_device) - # the PyTorch MultiHeadAttention has a private function _reset_parameters() - # other layers have a public reset_parameters()... go figure - reset_parameters = getattr(m, "reset_parameters", None) - reset_parameters_private = getattr(m, "_reset_parameters", None) + self.bucket_mtxs[self._pdb_key(pdb_fn)] = bucket_mtx - if callable(reset_parameters) and callable(reset_parameters_private): - raise RuntimeError( - "Module has both public and private methods for resetting parameters. " - "This is unexpected... probably should just call the public one." 
- ) + @staticmethod + def _pdb_key(pdb_fn): + """return a unique key for the given pdb_fn, used to map unique PDBs""" + # note this key does NOT currently support PDBs with the same basename but different paths + # assumes every PDB is in the format .pdb + # should be a compatible with being a class attribute, as it is used as a pytorch buffer name + return f"pdb_{basename(pdb_fn).split('.')[0]}" - if callable(reset_parameters): - m.reset_parameters() + def _init_pdbs(self, pdb_fns): + start = time.time() - if callable(reset_parameters_private): - m._reset_parameters() + if pdb_fns is None: + # nothing to initialize if pdb_fns is None + return + # make sure pdb_fns is a list + if not isinstance(pdb_fns, list) and not isinstance(pdb_fns, tuple): + pdb_fns = [pdb_fns] -class SequentialWithArgs(nn.Sequential): - def forward(self, x, **kwargs): - for module in self: - if isinstance(module, RelativeTransformerEncoder) or isinstance( - module, SequentialWithArgs - ): - # for relative transformer encoders, pass in kwargs (pdb_fn) - x = module(x, **kwargs) - else: - # for all modules, don't pass in kwargs - x = module(x) - return x + # init each pdb fn in the list + for pdb_fn in pdb_fns: + self._init_pdb(pdb_fn) + print("Initialized PDB bucket matrices in: {:.3f}".format(time.time() - start)) -class PositionalEncoding(nn.Module): - # originally from https://pytorch.org/tutorials/beginner/transformer_tutorial.html - # they have since updated their implementation, but it is functionally equivalent - def __init__(self, d_model, dropout=0.1, max_len=5000): - super(PositionalEncoding, self).__init__() - self.dropout = nn.Dropout(p=dropout) - - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - # note the implementation on Pytorch's 
website expects [seq_len, batch_size, embedding_dim] - # however our data is in [batch_size, seq_len, embedding_dim] (i.e. batch_first) - # fixed by changing pe = pe.unsqueeze(0).transpose(0, 1) to pe = pe.unsqueeze(0) - # also down below, changing our indexing into the position encoding to reflect new dimensions - # pe = pe.unsqueeze(0).transpose(0, 1) - pe = pe.unsqueeze(0) - self.register_buffer("pe", pe) - - def forward(self, x, **kwargs): - # note the implementation on Pytorch's website expects [seq_len, batch_size, embedding_dim] - # however our data is in [batch_size, seq_len, embedding_dim] (i.e. batch_first) - # fixed by changing x = x + self.pe[:x.size(0)] to x = x + self.pe[:, :x.size(1), :] - # x = x + self.pe[:x.size(0), :] - x = x + self.pe[:, : x.size(1), :] - return self.dropout(x) - - -class ScaledEmbedding(nn.Module): - # https://pytorch.org/tutorials/beginner/translation_transformer.html - # a helper function for embedding that scales by sqrt(d_model) in the forward() - # makes it, so we don't have to do the scaling in the main AttnModel forward() + def _init_pdb(self, pdb_fn): + """process a pdb file for use with structure-based relative attention""" + # if pdb_fn is not a full path, default to the path data/pdb_files/ + if dirname(pdb_fn) == "": + # handle the case where the pdb file is in the current working directory + # if there is a PDB file in the cwd.... then just use it as is. otherwise, append the default. + if not isfile(pdb_fn): + pdb_fn = join(self.default_pdb_dir, pdb_fn) - # todo: be aware of embedding scaling factor - # regarding the scaling factor, it's unclear exactly what the purpose is and whether it is needed - # there are several theories on why it is used, and it shows up in all the transformer reference implementations - # https://datascience.stackexchange.com/questions/87906/transformer-model-why-are-word-embeddings-scaled-before-adding-positional-encod - # 1. 
Has something to do with weight sharing between the embedding and the decoder output - # 2. Scales up the embeddings so the signal doesn't get overwhelmed when adding the absolute positional encoding - # 3. It cancels out with the scaling factor in scaled dot product attention, and helps make the model robust - # to the choice of embedding_len - # 4. It's not actually needed + # create a structure graph from the pdb_fn and contact threshold + cbeta_mtx = cbeta_distance_matrix(pdb_fn) + structure_graph = dist_thresh_graph(cbeta_mtx, self.contact_threshold) - # Regarding #1, not really sure about this. In section 3.4 of attention is all you need, - # that's where they state they multiply the embedding weights by sqrt(d_model), and the context is that they - # are sharing the same weight matrix between the two embedding layers and the pre-softmax linear transformation. - # there may be a reason that we want those weights scaled differently for the embedding layers vs. the linear - # transformation. It might have something to do with the scale at which embedding weights are initialized - # is more appropriate for the decoder linear transform vs how they are used in the attention function. Might have - # something to do with computing the correct next-token probabilities. Overall, I'm really not sure about this, - # but we aren't using a decoder anyway. So if this is the reason, then we don't need to perform the multiply. + # bucket_mtx indexes into the embedding lookup table to create the final distance matrix + bucket_mtx = self._compute_bucket_mtx(structure_graph) - # Regarding #2, it seems like in one implementation of transformers (fairseq), the sinusoidal positional encoding - # has a range of (-1.0, 1.0), but the word embedding are initialized with mean 0 and s.d embedding_dim ** -0.5, - # which for embedding_dim=512, is a range closer to (-0.10, 0.10). Thus, the positional embedding would overwhelm - # the word embeddings when they are added together. 
The scaling factor increases the signal of the word embeddings. - # for embedding_dim=512, it scales word embeddings by 22, increasing range of the word embeddings to (-2.2, 2.2). - # link to fairseq implementation, search for nn.init to see them do the initialization - # https://fairseq.readthedocs.io/en/v0.7.1/_modules/fairseq/models/transformer.html - # - # For PyTorch, PyTorch initializes nn.Embedding with a standard normal distribution mean 0, variance 1: N(0,1). - # this puts the range for the word embeddings around (-3, 3). the pytorch implementation for positional encoding - # also has a range of (-1.0, 1.0). So already, these are much closer in scale, and it doesn't seem like we need - # to increase the scale of the word embeddings. However, PyTorch example still multiply by the scaling factor - # unclear whether this is just a carryover that is not actually needed, or if there is a different reason - # - # EDIT! I just realized that even though nn.Embedding defaults to a range of around (-3, 3), the PyTorch - # transformer example actually re-initializes them using a uniform distribution in the range of (-0.1, 0.1) - # that makes it very similar to the fairseq implementation, so the scaling factor that PyTorch uses actually would - # bring the word embedding and positional encodings much closer in scale. So this could be the reason why pytorch - # does it + self._set_bucket_mtx(pdb_fn, bucket_mtx) - # Regarding #3, I don't think so. Firstly, does it actually cancel there? Secondly, the purpose of the scaling - # factor in scaled dot product attention, according to attention is all you need, is to counteract dot products - # that are very high in magnitude due to choice of large mbedding length (aka d_k). The problem with high magnitude - # dot products is that potentially, the softmax is pushed into regions where it has extremely small gradients, - # making learning difficult. 
If the scaling factor in the embedding was meant to counteract the scaling factor in - # scaled dot product attention, then what would be the point of doing all that? + def _compute_bucketed_neighbors(self, structure_graph, source_node): + """gets the bucketed neighbors from the given source node and structure graph""" + if self.clipping_threshold < 0: + raise ValueError("Clipping threshold must be >= 0") - # Regarding #4, I don't think the scaling will have any effects in practice, it's probably not needed + sspl = _inv_dict( + nx.single_source_shortest_path_length(structure_graph, source_node) + ) - # Overall, I think #2 is the most likely reason why this scaling is performed. In theory, I think - # even if the scaling wasn't performed, the network might learn to up-scale the word embedding weights to increase - # word embedding signal vs. the position signal on its own. Another question I have is why not just initialize - # the embedding weights to have higher initial values? Why put it in the range (-0.1, 0.1)? - # - # The fact that most implementations have this scaling concerns me, makes me think I might be missing something. - # For our purposes, we can train a couple models to see if scaling has any positive or negative effect. - # Still need to think about potential effects of this scaling on relative position embeddings. 
+ if self.clipping_threshold is not None: + num_buckets = 1 + self.clipping_threshold + sspl = _combine_d(sspl, self.clipping_threshold, num_buckets - 1) - def __init__(self, num_embeddings: int, embedding_dim: int, scale: bool): - super(ScaledEmbedding, self).__init__() - self.embedding = nn.Embedding(num_embeddings, embedding_dim) - self.emb_size = embedding_dim - self.embed_scale = math.sqrt(self.emb_size) + return sspl - self.scale = scale + def _compute_bucket_mtx(self, structure_graph): + """get the bucket_mtx for the given structure_graph + calls _get_bucketed_neighbors for every node in the structure_graph""" + num_residues = len(list(structure_graph)) - self.init_weights() + # index into the embedding lookup table to create the final distance matrix + bucket_mtx = torch.zeros(num_residues, num_residues, dtype=torch.long) - def init_weights(self): - # todo: not sure why PyTorch example initializes weights like this - # might have something to do with word embedding scaling factor (see above) - # could also just try the default weight initialization for nn.Embedding() - init_range = 0.1 - self.embedding.weight.data.uniform_(-init_range, init_range) + for node_num in sorted(list(structure_graph)): + bucketed_neighbors = self._compute_bucketed_neighbors( + structure_graph, node_num + ) - def forward(self, tokens: Tensor, **kwargs): - if self.scale: - return self.embedding(tokens.long()) * self.embed_scale - else: - return self.embedding(tokens.long()) + for bucket_num, neighbors in bucketed_neighbors.items(): + bucket_mtx[node_num, neighbors] = bucket_num + return bucket_mtx -class FCBlock(nn.Module): - """a fully connected block with options for batchnorm and dropout - can extend in the future with option for different activation, etc""" - def __init__( - self, - in_features: int, - num_hidden_nodes: int = 64, - use_batchnorm: bool = False, - use_layernorm: bool = False, - norm_before_activation: bool = False, - use_dropout: bool = False, - dropout_rate: 
float = 0.2, - activation: str = "relu", - ): +class RelativePosition(nn.Module): + """creates the embedding lookup table E_r and computes R + note: implemented as a plain nn.Module; a registered dummy_buffer is used + to determine the correct device at runtime (inheriting pl.LightningModule + would be an alternative, giving easy access to the device via self.device) + """ + def __init__(self, embedding_len: int, clipping_threshold: int): + """ + embedding_len: the length of the embedding, may be d_model, or d_model // num_heads for multihead + clipping_threshold: the maximum relative position, referred to as k by Shaw et al. + """ super().__init__() + self.embedding_len = embedding_len + self.clipping_threshold = clipping_threshold + # for sequence-based distances, the number of embeddings is 2*k+1, where k is the clipping threshold + num_embeddings = 2 * clipping_threshold + 1 - if use_batchnorm and use_layernorm: - raise ValueError( - "Only one of use_batchnorm or use_layernorm can be set to True" - ) + # this is the embedding lookup table E_r + self.embeddings_table = nn.Embedding(num_embeddings, embedding_len) - self.use_batchnorm = use_batchnorm - self.use_dropout = use_dropout - self.use_layernorm = use_layernorm - self.norm_before_activation = norm_before_activation + # for getting the correct device for range vectors in forward + self.register_buffer("dummy_buffer", torch.empty(0), persistent=False) - self.fc = nn.Linear(in_features=in_features, out_features=num_hidden_nodes) + def forward(self, length_q, length_k): + # supports different length sequences, but in self-attention length_q and length_k are the same + range_vec_q = torch.arange(length_q, device=self.dummy_buffer.device) + range_vec_k = torch.arange(length_k, device=self.dummy_buffer.device) - self.activation = get_activation_fn(activation, functional=False) + # this sets up the standard sequence-based distance matrix for relative positions + # the current position is 0, positions to the right 
are +1, +2, etc, and to the left -1, -2, etc + distance_mat = range_vec_k[None, :] - range_vec_q[:, None] + distance_mat_clipped = torch.clamp( + distance_mat, -self.clipping_threshold, self.clipping_threshold + ) - if use_batchnorm: - self.norm = nn.BatchNorm1d(num_hidden_nodes) + # convert to indices, indexing into the embedding table + final_mat = (distance_mat_clipped + self.clipping_threshold).long() - if use_layernorm: - self.norm = nn.LayerNorm(num_hidden_nodes) + # compute matrix R by grabbing the embeddings from the embedding lookup table + embeddings = self.embeddings_table(final_mat) - if use_dropout: - self.dropout = nn.Dropout(p=dropout_rate) + return embeddings - def forward(self, x, **kwargs): - x = self.fc(x) - # norm can be before or after activation, using flag - if (self.use_batchnorm or self.use_layernorm) and self.norm_before_activation: - x = self.norm(x) +class RelativeMultiHeadAttention(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + dropout, + pos_encoding, + clipping_threshold, + contact_threshold, + pdb_fns, + ): + """ + Multi-head attention with relative position embeddings. Input data should be in batch_first format. 
+ :param embed_dim: aka d_model, aka hid_dim + :param num_heads: number of heads + :param dropout: how much dropout for scaled dot product attention - x = self.activation(x) + :param pos_encoding: what type of positional encoding to use, relative or relative3D + :param clipping_threshold: clipping threshold for relative position embedding + :param contact_threshold: for relative_3D, the threshold in angstroms for the contact map + :param pdb_fns: pdb file(s) to set up the relative position object - # batchnorm being applied after activation, there is some discussion on this online - if ( - self.use_batchnorm or self.use_layernorm - ) and not self.norm_before_activation: - x = self.norm(x) + """ + super().__init__() - # dropout being applied last - if self.use_dropout: - x = self.dropout(x) + assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" - return x + # model dimensions + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + # pos encoding stuff + self.pos_encoding = pos_encoding + self.clipping_threshold = clipping_threshold + self.contact_threshold = contact_threshold + if pdb_fns is not None and not isinstance(pdb_fns, list): + pdb_fns = [pdb_fns] + self.pdb_fns = pdb_fns -class TaskSpecificPredictionLayers(nn.Module): - """Constructs num_tasks [dense(num_hidden_nodes)+relu+dense(1)] layers, each independently transforming input - into a single output node. All num_tasks outputs are then concatenated into a single tensor. - """ + # relative position embeddings for use with keys and values + # Shaw et al. uses relative position information for both keys and values + # Huang et al. 
only uses it for the keys, which is probably enough + if pos_encoding == "relative": + self.relative_position_k = RelativePosition( + self.head_dim, self.clipping_threshold + ) + self.relative_position_v = RelativePosition( + self.head_dim, self.clipping_threshold + ) + elif pos_encoding == "relative_3D": + self.relative_position_k = RelativePosition3D( + self.head_dim, + self.contact_threshold, + self.clipping_threshold, + self.pdb_fns, + ) + self.relative_position_v = RelativePosition3D( + self.head_dim, + self.contact_threshold, + self.clipping_threshold, + self.pdb_fns, + ) + else: + raise ValueError("unrecognized pos_encoding: {}".format(pos_encoding)) - # todo: the independent layers are run in sequence rather than in parallel, causing a slowdown that - # scales with the number of tasks. might be able to run in parallel by hacking convolution operation - # https://stackoverflow.com/questions/58374980/run-multiple-models-of-an-ensemble-in-parallel-with-pytorch - # https://github.com/pytorch/pytorch/issues/54147 - # https://github.com/pytorch/pytorch/issues/36459 + # WQ, WK, and WV from attention is all you need + # note these default to bias=True, same as PyTorch implementation + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) - def __init__( - self, - num_tasks: int, - in_features: int, - num_hidden_nodes: int = 64, - use_batchnorm: bool = False, - use_dropout: bool = False, - dropout_rate: float = 0.2, - activation: str = "relu", - ): + # WO from attention is all you need + # used for the final projection when computing multi-head attention + # PyTorch uses NonDynamicallyQuantizableLinear instead of Linear to avoid triggering an obscure + # error quantizing the model https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L122 + # todo: if quantizing the model, explore if the above is a concern for us + self.out_proj = nn.Linear(embed_dim, 
embed_dim) - super().__init__() + # dropout for scaled dot product attention + self.dropout = nn.Dropout(dropout) - # each task-specific layer outputs a single node, - # which can be combined with torch.cat into prediction vector - self.task_specific_pred_layers = nn.ModuleList() - for i in range(num_tasks): - layers = [ - FCBlock( - in_features=in_features, - num_hidden_nodes=num_hidden_nodes, - use_batchnorm=use_batchnorm, - use_dropout=use_dropout, - dropout_rate=dropout_rate, - activation=activation, - ), - nn.Linear(in_features=num_hidden_nodes, out_features=1), - ] - self.task_specific_pred_layers.append(nn.Sequential(*layers)) + # scaling factor for scaled dot product attention + scale = torch.sqrt(torch.FloatTensor([self.head_dim])) + # persistent=False if you don't want to save it inside state_dict + self.register_buffer("scale", scale) - def forward(self, x, **kwargs): - # run each task-specific layer and concatenate outputs into a single output vector - task_specific_outputs = [] - for layer in self.task_specific_pred_layers: - task_specific_outputs.append(layer(x)) + # toggles meant to be set directly by user + self.need_weights = False + self.average_attn_weights = True - output = torch.cat(task_specific_outputs, dim=1) - return output + def _compute_attn_weights(self, query, key, len_q, len_k, batch_size, mask, pdb_fn): + """computes the attention weights (a "compatibility function" of queries with corresponding keys)""" + # calculate the first term in the numerator attn1, which is Q*K + # todo: pytorch reshapes q,k and v to 3 dimensions (similar to how r_q2 is below) + # is that functionally equivalent to what we're doing? is their way faster? 
+ # r_q1 = [batch_size, num_heads, len_q, head_dim] + r_q1 = query.view(batch_size, len_q, self.num_heads, self.head_dim).permute( + 0, 2, 1, 3 + ) + # todo: we could directly permute r_k1 to [batch_size, num_heads, head_dim, len_k] + # to make it compatible for matrix multiplication with r_q1, instead of 2-step approach + # r_k1 = [batch_size, num_heads, len_k, head_dim] + r_k1 = key.view(batch_size, len_k, self.num_heads, self.head_dim).permute( + 0, 2, 1, 3 + ) + # attn1 = [batch_size, num_heads, len_q, len_k] + attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2)) -class GlobalAveragePooling(nn.Module): - """helper class for global average pooling""" + # calculate the second term in the numerator attn2, which is Q*R + # r_q2 = [query_len, batch_size * num_heads, head_dim] + r_q2 = ( + query.permute(1, 0, 2) + .contiguous() + .view(len_q, batch_size * self.num_heads, self.head_dim) + ) - def __init__(self, dim=1): - super().__init__() - # our data is in [batch_size, sequence_length, embedding_length] - # with global pooling, we want to pool over the sequence dimension (dim=1) - self.dim = dim + # todo: support multiple different PDB base structures per batch + # one option: + # - require batches to be all the same protein + # - add argument to forward() to accept the PDB file for the protein in the batch + # - then we just pass in the PDB file to relative position's forward() + # to support multiple different structures per batch: + # - add argument to forward() to accept PDB files, one for each item in batch + # - make corresponding changes in relative_position object to return R for each structure + # - note: if there are a lot of different structures, and the sequence lengths are long, + # this could be memory prohibitive because R (rel_pos_k) can take up a lot of mem for long seqs + # - adjust the attn2 calculation to factor in the multiple different R matrices. 
+ # the way to do this might have to be to do multiple matmuls, one for each each + # basically, would split up r_q2 into several matrices grouped by structure, and then + # multiply with corresponding R, then combine back into the exact same order of the original r_q2 + # note: this may be computationally intensive (splitting, more matrix muliplies, joining) + # another option would be to create views(?), repeating the different Rs so we can do a + # a matris multiply directly with r_q2 + # - would shapes be affected if there was padding in the queries, keys, values? - def forward(self, x, **kwargs): - return torch.mean(x, dim=self.dim) + if self.pos_encoding == "relative": + # rel_pos_k = [len_q, len_k, head_dim] + rel_pos_k = self.relative_position_k(len_q, len_k) + elif self.pos_encoding == "relative_3D": + # rel_pos_k = [sequence length (from PDB structure), head_dim] + rel_pos_k = self.relative_position_k(pdb_fn) + else: + raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding)) + # the matmul basically computes the dot product between each input position’s query vector and + # its corresponding relative position embeddings across all input sequences in the heads and batch + # attn2 = [batch_size * num_heads, len_q, len_k] + attn2 = torch.matmul(r_q2, rel_pos_k.transpose(1, 2)).transpose(0, 1) + # attn2 = [batch_size, num_heads, len_q, len_k] + attn2 = attn2.contiguous().view(batch_size, self.num_heads, len_q, len_k) -class CLSPooling(nn.Module): - """helper class for CLS token extraction""" + # calculate attention weights + attn_weights = (attn1 + attn2) / self.scale - def __init__(self, cls_position=0): - super().__init__() + # apply mask if given + if mask is not None: + # todo: pytorch uses float("-inf") instead of -1e10 + attn_weights = attn_weights.masked_fill(mask == 0, -1e10) - # the position of the CLS token in the sequence dimension - # currently, the CLS token is in the first position, but may move it to the last position - 
self.cls_position = cls_position + # softmax gives us attn_weights weights + attn_weights = torch.softmax(attn_weights, dim=-1) + # attn_weights = [batch_size, num_heads, len_q, len_k] + attn_weights = self.dropout(attn_weights) - def forward(self, x, **kwargs): - # assumes input is in [batch_size, sequence_len, embedding_len] - # thus sequence dimension is dimension 1 - return x[:, self.cls_position, :] + return attn_weights + def _compute_avg_val( + self, value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn + ): + # todo: add option to not factor in relative position embeddings in value calculation + # calculate the first term, the attn*values + # r_v1 = [batch_size, num_heads, len_v, head_dim] + r_v1 = value.view(batch_size, len_v, self.num_heads, self.head_dim).permute( + 0, 2, 1, 3 + ) + # avg1 = [batch_size, num_heads, len_q, head_dim] + avg1 = torch.matmul(attn_weights, r_v1) -class TransformerEncoderWrapper(nn.TransformerEncoder): - """wrapper around PyTorch's TransformerEncoder that re-initializes layer parameters, - so each transformer encoder layer has a different initialization""" + # calculate the second term, the attn*R + # similar to how relative embeddings are factored in the attention weights calculation + if self.pos_encoding == "relative": + # rel_pos_v = [query_len, value_len, head_dim] + rel_pos_v = self.relative_position_v(len_q, len_v) + elif self.pos_encoding == "relative_3D": + # rel_pos_v = [sequence length (from PDB structure), head_dim] + rel_pos_v = self.relative_position_v(pdb_fn) + else: + raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding)) - # todo: PyTorch is changing its transformer API... 
check up on and see if there is a better way - def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True): - super().__init__(encoder_layer, num_layers, norm) - if reset_params: - self.apply(reset_parameters_helper) + # r_attn_weights = [len_q, batch_size * num_heads, len_v] + r_attn_weights = ( + attn_weights.permute(2, 0, 1, 3) + .contiguous() + .view(len_q, batch_size * self.num_heads, len_k) + ) + avg2 = torch.matmul(r_attn_weights, rel_pos_v) + # avg2 = [batch_size, num_heads, len_q, head_dim] + avg2 = ( + avg2.transpose(0, 1) + .contiguous() + .view(batch_size, self.num_heads, len_q, self.head_dim) + ) + # calculate avg value + x = avg1 + avg2 # [batch_size, num_heads, len_q, head_dim] + x = x.permute( + 0, 2, 1, 3 + ).contiguous() # [batch_size, len_q, num_heads, head_dim] + # x = [batch_size, len_q, embed_dim] + x = x.view(batch_size, len_q, self.embed_dim) -class AttnModel(nn.Module): - # https://pytorch.org/tutorials/beginner/transformer_tutorial.html + return x - @staticmethod - def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser], add_help=False) + def forward(self, query, key, value, pdb_fn=None, mask=None): + # query = [batch_size, q_len, embed_dim] + # key = [batch_size, k_len, embed_dim] + # value = [batch_size, v_en, embed_dim] + batch_size = query.shape[0] + len_k, len_q, len_v = (key.shape[1], query.shape[1], value.shape[1]) - parser.add_argument( - "--pos_encoding", - type=str, - default="absolute", - choices=["none", "absolute", "relative", "relative_3D"], - help="what type of positional encoding to use", - ) - parser.add_argument( - "--pos_encoding_dropout", - type=float, - default=0.1, - help="out much dropout to use in positional encoding, for pos_encoding==absolute", - ) - parser.add_argument( - "--clipping_threshold", - type=int, - default=3, - help="clipping threshold for relative position embedding, for relative and relative_3D", - ) - parser.add_argument( - "--contact_threshold", 
- type=int, - default=7, - help="threshold, in angstroms, for contact map, for relative_3D", - ) - parser.add_argument("--embedding_len", type=int, default=128) - parser.add_argument("--num_heads", type=int, default=2) - parser.add_argument("--num_hidden", type=int, default=64) - parser.add_argument("--num_enc_layers", type=int, default=2) - parser.add_argument("--enc_layer_dropout", type=float, default=0.1) - parser.add_argument( - "--use_final_encoder_norm", action="store_true", default=False + # in projection (multiply inputs by WQ, WK, WV) + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + # first compute the attention weights, then multiply with values + # attn = [batch size, num_heads, len_q, len_k] + attn_weights = self._compute_attn_weights( + query, key, len_q, len_k, batch_size, mask, pdb_fn ) - parser.add_argument( - "--global_average_pooling", action="store_true", default=False + # take weighted average of values (weighted by attention weights) + attn_output = self._compute_avg_val( + value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn ) - parser.add_argument("--cls_pooling", action="store_true", default=False) - parser.add_argument( - "--use_task_specific_layers", - action="store_true", - default=False, - help="exclusive with use_final_hidden_layer; takes priority over use_final_hidden_layer" - " if both flags are set", - ) - parser.add_argument("--task_specific_hidden_nodes", type=int, default=64) - parser.add_argument( - "--use_final_hidden_layer", action="store_true", default=False - ) - parser.add_argument("--final_hidden_size", type=int, default=64) - parser.add_argument( - "--use_final_hidden_layer_norm", action="store_true", default=False - ) - parser.add_argument( - "--final_hidden_layer_norm_before_activation", - action="store_true", - default=False, - ) - parser.add_argument( - "--use_final_hidden_layer_dropout", action="store_true", default=False - ) - parser.add_argument( - 
"--final_hidden_layer_dropout_rate", type=float, default=0.2 - ) + # output projection + # attn_output = [batch_size, len_q, embed_dim] + attn_output = self.out_proj(attn_output) - parser.add_argument( - "--activation", - type=str, - default="relu", - help="activation function used for all activations in the network", - ) - return parser + if self.need_weights: + # return attention weights in addition to attention + # average the weights over the heads (to get overall attention) + # attn_weights = [batch_size, len_q, len_k] + if self.average_attn_weights: + attn_weights = attn_weights.sum(dim=1) / self.num_heads + return {"attn_output": attn_output, "attn_weights": attn_weights} + else: + return attn_output + + +class RelativeTransformerEncoderLayer(nn.Module): + """ + d_model: the number of expected features in the input (required). + nhead: the number of heads in the MultiHeadAttention models (required). + clipping_threshold: the clipping threshold for relative position embeddings + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + norm_first: if ``True``, layer norm is done prior to attention and feedforward + operations, respectively. Otherwise, it's done after. Default: ``False`` (after). + """ + + # this is some kind of torch jit compiling helper... 
will also ensure these values don't change + __constants__ = ["batch_first", "norm_first"] def __init__( self, - # data args - num_tasks: int, - aa_seq_len: int, - num_tokens: int, - # transformer encoder model args - pos_encoding: str = "absolute", - pos_encoding_dropout: float = 0.1, - clipping_threshold: int = 3, - contact_threshold: int = 7, - pdb_fns: List[str] = None, - embedding_len: int = 64, - num_heads: int = 2, - num_hidden: int = 64, - num_enc_layers: int = 2, - enc_layer_dropout: float = 0.1, - use_final_encoder_norm: bool = False, - # pooling to fixed-length representation - global_average_pooling: bool = True, - cls_pooling: bool = False, - # prediction layers - use_task_specific_layers: bool = False, - task_specific_hidden_nodes: int = 64, - use_final_hidden_layer: bool = False, - final_hidden_size: int = 64, - use_final_hidden_layer_norm: bool = False, - final_hidden_layer_norm_before_activation: bool = False, - use_final_hidden_layer_dropout: bool = False, - final_hidden_layer_dropout_rate: float = 0.2, - # activation function - activation: str = "relu", - *args, - **kwargs, - ): - - super().__init__() + d_model, + nhead, + pos_encoding="relative", + clipping_threshold=3, + contact_threshold=7, + pdb_fns=None, + dim_feedforward=2048, + dropout=0.1, + activation=F.relu, + layer_norm_eps=1e-5, + norm_first=False, + ) -> None: - # store embedding length for use in the forward function - self.embedding_len = embedding_len - self.aa_seq_len = aa_seq_len + self.batch_first = True - # build up layers - layers = collections.OrderedDict() + super(RelativeTransformerEncoderLayer, self).__init__() - # amino acid embedding - layers["embedder"] = ScaledEmbedding( - num_embeddings=num_tokens, embedding_dim=embedding_len, scale=True + self.self_attn = RelativeMultiHeadAttention( + d_model, + nhead, + dropout, + pos_encoding, + clipping_threshold, + contact_threshold, + pdb_fns, ) - # absolute positional encoding - if pos_encoding == "absolute": - 
layers["pos_encoder"] = PositionalEncoding( - embedding_len, dropout=pos_encoding_dropout, max_len=512 - ) + # feed forward model + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model) - # transformer encoder layer for none or absolute positional encoding - if pos_encoding in ["none", "absolute"]: - encoder_layer = torch.nn.TransformerEncoderLayer( - d_model=embedding_len, - nhead=num_heads, - dim_feedforward=num_hidden, - dropout=enc_layer_dropout, - activation=get_activation_fn(activation), - norm_first=True, - batch_first=True, - ) + self.norm_first = norm_first + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) - # layer norm that is used after the transformer encoder layers - # if the norm_first is False, this is *redundant* and not needed - # but if norm_first is True, this can be used to normalize outputs from - # the transformer encoder before inputting to the final fully connected layer - encoder_norm = None - if use_final_encoder_norm: - encoder_norm = nn.LayerNorm(embedding_len) + # Legacy string support for activation function. 
+ if isinstance(activation, str): + self.activation = get_activation_fn(activation) + else: + self.activation = activation - layers["tr_encoder"] = TransformerEncoderWrapper( - encoder_layer=encoder_layer, - num_layers=num_enc_layers, - norm=encoder_norm, - ) + def forward(self, src: Tensor, pdb_fn=None) -> Tensor: + x = src + if self.norm_first: + x = x + self._sa_block(self.norm1(x), pdb_fn=pdb_fn) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1(x + self._sa_block(x)) + x = self.norm2(x + self._ff_block(x)) - # transformer encoder layer for relative position encoding - elif pos_encoding in ["relative", "relative_3D"]: - relative_encoder_layer = RelativeTransformerEncoderLayer( - d_model=embedding_len, - nhead=num_heads, - pos_encoding=pos_encoding, - clipping_threshold=clipping_threshold, - contact_threshold=contact_threshold, - pdb_fns=pdb_fns, - dim_feedforward=num_hidden, - dropout=enc_layer_dropout, - activation=get_activation_fn(activation), - norm_first=True, - ) + return x - encoder_norm = None - if use_final_encoder_norm: - encoder_norm = nn.LayerNorm(embedding_len) + # self-attention block + def _sa_block(self, x: Tensor, pdb_fn=None) -> Tensor: + x = self.self_attn(x, x, x, pdb_fn=pdb_fn) + if isinstance(x, dict): + # handle the case where we are returning attention weights + x = x["attn_output"] + return self.dropout1(x) - layers["tr_encoder"] = RelativeTransformerEncoder( - encoder_layer=relative_encoder_layer, - num_layers=num_enc_layers, - norm=encoder_norm, - ) + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) - # GLOBAL AVERAGE POOLING OR CLS TOKEN - # set up the layers and output shapes (i.e. 
input shapes for the pred layer) - if global_average_pooling: - # pool over the sequence dimension - layers["avg_pooling"] = GlobalAveragePooling(dim=1) - pred_layer_input_features = embedding_len - elif cls_pooling: - layers["cls_pooling"] = CLSPooling(cls_position=0) - pred_layer_input_features = embedding_len - else: - # no global average pooling or CLS token - # sequence dimension is still there, just flattened - layers["flatten"] = nn.Flatten() - pred_layer_input_features = embedding_len * aa_seq_len - # PREDICTION - if use_task_specific_layers: - # task specific prediction layers (nonlinear transform for each task) - layers["prediction"] = TaskSpecificPredictionLayers( - num_tasks=num_tasks, - in_features=pred_layer_input_features, - num_hidden_nodes=task_specific_hidden_nodes, - activation=activation, - ) - elif use_final_hidden_layer: - # combined prediction linear (linear transform for each task) - layers["fc1"] = FCBlock( - in_features=pred_layer_input_features, - num_hidden_nodes=final_hidden_size, - use_batchnorm=False, - use_layernorm=use_final_hidden_layer_norm, - norm_before_activation=final_hidden_layer_norm_before_activation, - use_dropout=use_final_hidden_layer_dropout, - dropout_rate=final_hidden_layer_dropout_rate, - activation=activation, - ) +class RelativeTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True): + super(RelativeTransformerEncoder, self).__init__() + # using get_clones means all layers have the same initialization + # this is also a problem in PyTorch's TransformerEncoder implementation, which this is based on + # todo: PyTorch is changing its transformer API... 
check up on and see if there is a better way + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm - layers["prediction"] = nn.Linear( - in_features=final_hidden_size, out_features=num_tasks - ) - else: - layers["prediction"] = nn.Linear( - in_features=pred_layer_input_features, out_features=num_tasks - ) + # important because get_clones means all layers have same initialization + # should recursively reset parameters for all submodules + if reset_params: + self.apply(reset_parameters_helper) - # FINAL MODEL - self.model = SequentialWithArgs(layers) + def forward(self, src: Tensor, pdb_fn=None) -> Tensor: + output = src - def forward(self, x, **kwargs): - return self.model(x, **kwargs) + for mod in self.layers: + output = mod(output, pdb_fn=pdb_fn) + if self.norm is not None: + output = self.norm(output) -class Transpose(nn.Module): - """helper layer to swap data from (batch, seq, channels) to (batch, channels, seq) - used as a helper in the convolutional network which pytorch defaults to channels-first - """ + return output - def __init__(self, dims: Tuple[int, ...] 
= (1, 2)): - super().__init__() - self.dims = dims - def forward(self, x, **kwargs): - x = x.transpose(*self.dims).contiguous() - return x +def _get_clones(module, num_clones): + return nn.ModuleList([copy.deepcopy(module) for _ in range(num_clones)]) -def conv1d_out_shape(seq_len, kernel_size, stride=1, pad=0, dilation=1): - return (seq_len + (2 * pad) - (dilation * (kernel_size - 1)) - 1 // stride) + 1 +def _inv_dict(d): + """helper function for contact map-based position embeddings""" + inv = dict() + for k, v in d.items(): + # collect dict keys into lists based on value + inv.setdefault(v, list()).append(k) + for k, v in inv.items(): + # put in sorted order + inv[k] = sorted(v) + return inv -class ConvBlock(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - dilation: int = 1, - padding: str = "same", - use_batchnorm: bool = False, - use_layernorm: bool = False, - norm_before_activation: bool = False, - use_dropout: bool = False, - dropout_rate: float = 0.2, - activation: str = "relu", - ): +def _combine_d(d, threshold, combined_key): + """helper function for contact map-based position embeddings + d is a dictionary with ints as keys and lists as values. 
+ for all keys >= threshold, this function combines the values of those keys into a single list + """ + out_d = {} + for k, v in d.items(): + if k < threshold: + out_d[k] = v + elif k >= threshold: + if combined_key not in out_d: + out_d[combined_key] = v + else: + out_d[combined_key] += v + if combined_key in out_d: + out_d[combined_key] = sorted(out_d[combined_key]) + return out_d - super().__init__() - if use_batchnorm and use_layernorm: - raise ValueError( - "Only one of use_batchnorm or use_layernorm can be set to True" - ) +""" Encodes data in different formats """ - self.use_batchnorm = use_batchnorm - self.use_layernorm = use_layernorm - self.norm_before_activation = norm_before_activation - self.use_dropout = use_dropout - self.conv = nn.Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - padding=padding, - dilation=dilation, - ) +class Encoding(Enum): + INT_SEQS = auto() + ONE_HOT = auto() + + +class DataEncoder: + chars = [ + "*", + "A", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "K", + "L", + "M", + "N", + "P", + "Q", + "R", + "S", + "T", + "V", + "W", + "Y", + ] + num_chars = len(chars) + mapping = {c: i for i, c in enumerate(chars)} + + def __init__(self, encoding: Encoding = Encoding.INT_SEQS): + self.encoding = encoding + + def _encode_from_int_seqs(self, seq_ints): + if self.encoding == Encoding.INT_SEQS: + return seq_ints + elif self.encoding == Encoding.ONE_HOT: + one_hot = np.eye(self.num_chars)[seq_ints] + return one_hot.astype(np.float32) - self.activation = get_activation_fn(activation, functional=False) + def encode_sequences(self, char_seqs): + seq_ints = [] + for char_seq in char_seqs: + int_seq = [self.mapping[c] for c in char_seq] + seq_ints.append(int_seq) + seq_ints = np.array(seq_ints).astype(int) + return self._encode_from_int_seqs(seq_ints) - if use_batchnorm: - self.norm = nn.BatchNorm1d(out_channels) + def encode_variants(self, wt, variants): + # convert wild type seq to integer 
encoding + wt_int = np.zeros(len(wt), dtype=np.uint8) + for i, c in enumerate(wt): + wt_int[i] = self.mapping[c] - if use_layernorm: - self.norm = nn.LayerNorm(out_channels) + # tile the wild-type seq + seq_ints = np.tile(wt_int, (len(variants), 1)) - if use_dropout: - self.dropout = nn.Dropout(p=dropout_rate) + for i, variant in enumerate(variants): + # special handling if we want to encode the wild-type seq (it's already correct!) + if variant == "_wt": + continue - def forward(self, x, **kwargs): - x = self.conv(x) + # variants are a list of mutations [mutation1, mutation2, ....] + variant = variant.split(",") + for mutation in variant: + # mutations are in the form + position = int(mutation[1:-1]) + replacement = self.mapping[mutation[-1]] + seq_ints[i, position] = replacement - # norm can be before or after activation, using flag - if self.use_batchnorm and self.norm_before_activation: - x = self.norm(x) - elif self.use_layernorm and self.norm_before_activation: - x = self.norm(x.transpose(1, 2)).transpose(1, 2) + seq_ints = seq_ints.astype(int) + return self._encode_from_int_seqs(seq_ints) - x = self.activation(x) - # batchnorm being applied after activation, there is some discussion on this online - if self.use_batchnorm and not self.norm_before_activation: - x = self.norm(x) - elif self.use_layernorm and not self.norm_before_activation: - x = self.norm(x.transpose(1, 2)).transpose(1, 2) +def reset_parameters_helper(m: nn.Module): + """helper function for resetting model parameters, meant to be used with model.apply()""" - # dropout being applied after batchnorm, there is some discussion on this online - if self.use_dropout: - x = self.dropout(x) + # the PyTorch MultiHeadAttention has a private function _reset_parameters() + # other layers have a public reset_parameters()... 
go figure + reset_parameters = getattr(m, "reset_parameters", None) + reset_parameters_private = getattr(m, "_reset_parameters", None) - return x + if callable(reset_parameters) and callable(reset_parameters_private): + raise RuntimeError( + "Module has both public and private methods for resetting parameters. " + "This is unexpected... probably should just call the public one." + ) + if callable(reset_parameters): + m.reset_parameters() -class ConvModel2(nn.Module): - """convolutional source model that supports padded inputs, pooling, etc""" + if callable(reset_parameters_private): + m._reset_parameters() - @staticmethod - def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser], add_help=False) - parser.add_argument("--use_embedding", action="store_true", default=False) - parser.add_argument("--embedding_len", type=int, default=128) - parser.add_argument("--num_conv_layers", type=int, default=1) - parser.add_argument("--kernel_sizes", type=int, nargs="+", default=[7]) - parser.add_argument("--out_channels", type=int, nargs="+", default=[128]) - parser.add_argument("--dilations", type=int, nargs="+", default=[1]) - parser.add_argument( - "--padding", type=str, default="valid", choices=["valid", "same"] - ) - parser.add_argument("--use_conv_layer_norm", action="store_true", default=False) - parser.add_argument( - "--conv_layer_norm_before_activation", action="store_true", default=False - ) - parser.add_argument( - "--use_conv_layer_dropout", action="store_true", default=False - ) - parser.add_argument("--conv_layer_dropout_rate", type=float, default=0.2) +class SequentialWithArgs(nn.Sequential): + def forward(self, x, **kwargs): + for module in self: + if isinstance(module, RelativeTransformerEncoder) or isinstance( + module, SequentialWithArgs + ): + # for relative transformer encoders, pass in kwargs (pdb_fn) + x = module(x, **kwargs) + else: + # for all modules, don't pass in kwargs + x = module(x) + return x - 
parser.add_argument( - "--global_average_pooling", action="store_true", default=False - ) - parser.add_argument( - "--use_task_specific_layers", action="store_true", default=False - ) - parser.add_argument("--task_specific_hidden_nodes", type=int, default=64) - parser.add_argument( - "--use_final_hidden_layer", action="store_true", default=False - ) - parser.add_argument("--final_hidden_size", type=int, default=64) - parser.add_argument( - "--use_final_hidden_layer_norm", action="store_true", default=False - ) - parser.add_argument( - "--final_hidden_layer_norm_before_activation", - action="store_true", - default=False, - ) - parser.add_argument( - "--use_final_hidden_layer_dropout", action="store_true", default=False - ) - parser.add_argument( - "--final_hidden_layer_dropout_rate", type=float, default=0.2 - ) +class PositionalEncoding(nn.Module): + # originally from https://pytorch.org/tutorials/beginner/transformer_tutorial.html + # they have since updated their implementation, but it is functionally equivalent + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) - parser.add_argument( - "--activation", - type=str, - default="relu", - help="activation function used for all activations in the network", + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + # note the implementation on Pytorch's website expects [seq_len, batch_size, embedding_dim] + # however our data is in [batch_size, seq_len, embedding_dim] (i.e. 
batch_first) + # fixed by changing pe = pe.unsqueeze(0).transpose(0, 1) to pe = pe.unsqueeze(0) + # also down below, changing our indexing into the position encoding to reflect new dimensions + # pe = pe.unsqueeze(0).transpose(0, 1) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) - return parser - - def __init__( - self, - # data - num_tasks: int, - aa_seq_len: int, - aa_encoding_len: int, - num_tokens: int, - # convolutional model args - use_embedding: bool = False, - embedding_len: int = 64, - num_conv_layers: int = 1, - kernel_sizes: List[int] = (7,), - out_channels: List[int] = (128,), - dilations: List[int] = (1,), - padding: str = "valid", - use_conv_layer_norm: bool = False, - conv_layer_norm_before_activation: bool = False, - use_conv_layer_dropout: bool = False, - conv_layer_dropout_rate: float = 0.2, - # pooling - global_average_pooling: bool = True, - # prediction layers - use_task_specific_layers: bool = False, - task_specific_hidden_nodes: int = 64, - use_final_hidden_layer: bool = False, - final_hidden_size: int = 64, - use_final_hidden_layer_norm: bool = False, - final_hidden_layer_norm_before_activation: bool = False, - use_final_hidden_layer_dropout: bool = False, - final_hidden_layer_dropout_rate: float = 0.2, - # activation function - activation: str = "relu", - *args, - **kwargs, - ): - - super(ConvModel2, self).__init__() + def forward(self, x, **kwargs): + # note the implementation on Pytorch's website expects [seq_len, batch_size, embedding_dim] + # however our data is in [batch_size, seq_len, embedding_dim] (i.e. 
batch_first) + # fixed by changing x = x + self.pe[:x.size(0)] to x = x + self.pe[:, :x.size(1), :] + # x = x + self.pe[:x.size(0), :] + x = x + self.pe[:, : x.size(1), :] + return self.dropout(x) - # build up the layers - layers = collections.OrderedDict() - # amino acid embedding - if use_embedding: - layers["embedder"] = ScaledEmbedding( - num_embeddings=num_tokens, embedding_dim=embedding_len, scale=False - ) +class ScaledEmbedding(nn.Module): + # https://pytorch.org/tutorials/beginner/translation_transformer.html + # a helper function for embedding that scales by sqrt(d_model) in the forward() + # makes it, so we don't have to do the scaling in the main AttnModel forward() - # transpose the input to match PyTorch's expected format - layers["transpose"] = Transpose(dims=(1, 2)) + # todo: be aware of embedding scaling factor + # regarding the scaling factor, it's unclear exactly what the purpose is and whether it is needed + # there are several theories on why it is used, and it shows up in all the transformer reference implementations + # https://datascience.stackexchange.com/questions/87906/transformer-model-why-are-word-embeddings-scaled-before-adding-positional-encod + # 1. Has something to do with weight sharing between the embedding and the decoder output + # 2. Scales up the embeddings so the signal doesn't get overwhelmed when adding the absolute positional encoding + # 3. It cancels out with the scaling factor in scaled dot product attention, and helps make the model robust + # to the choice of embedding_len + # 4. 
It's not actually needed - # build up the convolutional layers - for layer_num in range(num_conv_layers): - # determine the number of input channels for the first convolutional layer - if layer_num == 0 and use_embedding: - # for the first convolutional layer, the in_channels is the embedding_len - in_channels = embedding_len - elif layer_num == 0 and not use_embedding: - # for the first convolutional layer, the in_channels is the aa_encoding_len - in_channels = aa_encoding_len - else: - in_channels = out_channels[layer_num - 1] + # Regarding #1, not really sure about this. In section 3.4 of attention is all you need, + # that's where they state they multiply the embedding weights by sqrt(d_model), and the context is that they + # are sharing the same weight matrix between the two embedding layers and the pre-softmax linear transformation. + # there may be a reason that we want those weights scaled differently for the embedding layers vs. the linear + # transformation. It might have something to do with the scale at which embedding weights are initialized + # is more appropriate for the decoder linear transform vs how they are used in the attention function. Might have + # something to do with computing the correct next-token probabilities. Overall, I'm really not sure about this, + # but we aren't using a decoder anyway. So if this is the reason, then we don't need to perform the multiply. 
- layers[f"conv{layer_num}"] = ConvBlock( - in_channels=in_channels, - out_channels=out_channels[layer_num], - kernel_size=kernel_sizes[layer_num], - dilation=dilations[layer_num], - padding=padding, - use_batchnorm=False, - use_layernorm=use_conv_layer_norm, - norm_before_activation=conv_layer_norm_before_activation, - use_dropout=use_conv_layer_dropout, - dropout_rate=conv_layer_dropout_rate, - activation=activation, - ) + # Regarding #2, it seems like in one implementation of transformers (fairseq), the sinusoidal positional encoding + # has a range of (-1.0, 1.0), but the word embedding are initialized with mean 0 and s.d embedding_dim ** -0.5, + # which for embedding_dim=512, is a range closer to (-0.10, 0.10). Thus, the positional embedding would overwhelm + # the word embeddings when they are added together. The scaling factor increases the signal of the word embeddings. + # for embedding_dim=512, it scales word embeddings by 22, increasing range of the word embeddings to (-2.2, 2.2). + # link to fairseq implementation, search for nn.init to see them do the initialization + # https://fairseq.readthedocs.io/en/v0.7.1/_modules/fairseq/models/transformer.html + # + # For PyTorch, PyTorch initializes nn.Embedding with a standard normal distribution mean 0, variance 1: N(0,1). + # this puts the range for the word embeddings around (-3, 3). the pytorch implementation for positional encoding + # also has a range of (-1.0, 1.0). So already, these are much closer in scale, and it doesn't seem like we need + # to increase the scale of the word embeddings. However, PyTorch example still multiply by the scaling factor + # unclear whether this is just a carryover that is not actually needed, or if there is a different reason + # + # EDIT! 
I just realized that even though nn.Embedding defaults to a range of around (-3, 3), the PyTorch + # transformer example actually re-initializes them using a uniform distribution in the range of (-0.1, 0.1) + # that makes it very similar to the fairseq implementation, so the scaling factor that PyTorch uses actually would + # bring the word embedding and positional encodings much closer in scale. So this could be the reason why pytorch + # does it - # handle transition from convolutional layers to fully connected layer - # either use global average pooling or flatten - # take into consideration whether we are using valid or same padding - if global_average_pooling: - # global average pooling (mean across the seq len dimension) - # the seq len dimensions is the last dimension (batch_size, num_filters, seq_len) - layers["avg_pooling"] = GlobalAveragePooling(dim=-1) - # the prediction layers will take num_filters input features - pred_layer_input_features = out_channels[-1] + # Regarding #3, I don't think so. Firstly, does it actually cancel there? Secondly, the purpose of the scaling + # factor in scaled dot product attention, according to attention is all you need, is to counteract dot products + # that are very high in magnitude due to choice of large mbedding length (aka d_k). The problem with high magnitude + # dot products is that potentially, the softmax is pushed into regions where it has extremely small gradients, + # making learning difficult. If the scaling factor in the embedding was meant to counteract the scaling factor in + # scaled dot product attention, then what would be the point of doing all that? - else: - # no global average pooling. flatten instead. 
- layers["flatten"] = nn.Flatten() - # calculate the final output len of the convolutional layers - # and the number of input features for the prediction layers - if padding == "valid": - # valid padding (aka no padding) results in shrinking length in progressive layers - conv_out_len = conv1d_out_shape( - aa_seq_len, kernel_size=kernel_sizes[0], dilation=dilations[0] - ) - for layer_num in range(1, num_conv_layers): - conv_out_len = conv1d_out_shape( - conv_out_len, - kernel_size=kernel_sizes[layer_num], - dilation=dilations[layer_num], - ) - pred_layer_input_features = conv_out_len * out_channels[-1] - else: - # padding == "same" - pred_layer_input_features = aa_seq_len * out_channels[-1] + # Regarding #4, I don't think the scaling will have any effects in practice, it's probably not needed - # prediction layer - if use_task_specific_layers: - layers["prediction"] = TaskSpecificPredictionLayers( - num_tasks=num_tasks, - in_features=pred_layer_input_features, - num_hidden_nodes=task_specific_hidden_nodes, - activation=activation, - ) + # Overall, I think #2 is the most likely reason why this scaling is performed. In theory, I think + # even if the scaling wasn't performed, the network might learn to up-scale the word embedding weights to increase + # word embedding signal vs. the position signal on its own. Another question I have is why not just initialize + # the embedding weights to have higher initial values? Why put it in the range (-0.1, 0.1)? + # + # The fact that most implementations have this scaling concerns me, makes me think I might be missing something. + # For our purposes, we can train a couple models to see if scaling has any positive or negative effect. + # Still need to think about potential effects of this scaling on relative position embeddings. 
class ScaledEmbedding(nn.Module):
    """Token embedding lookup with optional multiplication by sqrt(embedding_dim).

    The sqrt(d_model) scaling follows "Attention Is All You Need" (section 3.4).
    Combined with the small uniform init range below, the scaling likely keeps the
    token embeddings from being overwhelmed by the positional encodings when the
    two are summed (see the discussion comments above in this file).
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, scale: bool):
        """
        :param num_embeddings: vocabulary size (number of distinct tokens)
        :param embedding_dim: embedding length (d_model)
        :param scale: if True, multiply embeddings by sqrt(embedding_dim) in forward()
        """
        # modern zero-arg super(), consistent with the other modules in this file
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.emb_size = embedding_dim
        self.embed_scale = math.sqrt(self.emb_size)

        self.scale = scale

        self.init_weights()

    def init_weights(self):
        # todo: not sure why PyTorch example initializes weights like this
        # might have something to do with word embedding scaling factor (see above)
        # could also just try the default weight initialization for nn.Embedding()
        init_range = 0.1
        self.embedding.weight.data.uniform_(-init_range, init_range)

    def forward(self, tokens: Tensor, **kwargs) -> Tensor:
        # single lookup instead of duplicating the embedding call in both branches
        embedded = self.embedding(tokens.long())
        return embedded * self.embed_scale if self.scale else embedded
class FCBlock(nn.Module):
    """A single fully connected layer with optional normalization and dropout.

    Exactly one of batchnorm / layernorm may be enabled. The norm can run either
    before or after the activation (``norm_before_activation``); dropout, when
    enabled, is always applied last.
    """

    def __init__(
        self,
        in_features: int,
        num_hidden_nodes: int = 64,
        use_batchnorm: bool = False,
        use_layernorm: bool = False,
        norm_before_activation: bool = False,
        use_dropout: bool = False,
        dropout_rate: float = 0.2,
        activation: str = "relu",
    ):
        super().__init__()

        if use_batchnorm and use_layernorm:
            raise ValueError(
                "Only one of use_batchnorm or use_layernorm can be set to True"
            )

        # remember the configuration flags consulted in forward()
        self.use_batchnorm = use_batchnorm
        self.use_dropout = use_dropout
        self.use_layernorm = use_layernorm
        self.norm_before_activation = norm_before_activation

        self.fc = nn.Linear(in_features=in_features, out_features=num_hidden_nodes)
        self.activation = get_activation_fn(activation, functional=False)

        if use_batchnorm:
            self.norm = nn.BatchNorm1d(num_hidden_nodes)
        if use_layernorm:
            self.norm = nn.LayerNorm(num_hidden_nodes)
        if use_dropout:
            self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, **kwargs):
        has_norm = self.use_batchnorm or self.use_layernorm

        out = self.fc(x)
        if has_norm and self.norm_before_activation:
            out = self.norm(out)
        out = self.activation(out)
        # norm after activation: there is some discussion online about which order is better
        if has_norm and not self.norm_before_activation:
            out = self.norm(out)
        # dropout is applied last regardless of norm placement
        if self.use_dropout:
            out = self.dropout(out)
        return out
class TaskSpecificPredictionLayers(nn.Module):
    """Constructs num_tasks [dense(num_hidden_nodes)+relu+dense(1)] layers, each independently transforming input
    into a single output node. All num_tasks outputs are then concatenated into a single tensor.
    """

    # todo: the independent layers are run in sequence rather than in parallel, causing a slowdown that
    # scales with the number of tasks. might be able to run in parallel by hacking convolution operation
    # https://stackoverflow.com/questions/58374980/run-multiple-models-of-an-ensemble-in-parallel-with-pytorch
    # https://github.com/pytorch/pytorch/issues/54147
    # https://github.com/pytorch/pytorch/issues/36459

    def __init__(
        self,
        num_tasks: int,
        in_features: int,
        num_hidden_nodes: int = 64,
        use_batchnorm: bool = False,
        use_dropout: bool = False,
        dropout_rate: float = 0.2,
        activation: str = "relu",
    ):
        super().__init__()

        # one [FCBlock -> Linear(1)] head per task; each head emits a single node
        # and the heads' outputs are concatenated into the prediction vector
        self.task_specific_pred_layers = nn.ModuleList(
            nn.Sequential(
                FCBlock(
                    in_features=in_features,
                    num_hidden_nodes=num_hidden_nodes,
                    use_batchnorm=use_batchnorm,
                    use_dropout=use_dropout,
                    dropout_rate=dropout_rate,
                    activation=activation,
                ),
                nn.Linear(in_features=num_hidden_nodes, out_features=1),
            )
            for _ in range(num_tasks)
        )

    def forward(self, x, **kwargs):
        # run every task head on the same input, then stack the scalar outputs side by side
        per_task = [head(x) for head in self.task_specific_pred_layers]
        return torch.cat(per_task, dim=1)


class GlobalAveragePooling(nn.Module):
    """helper class for global average pooling"""

    def __init__(self, dim=1):
        super().__init__()
        # data is [batch_size, sequence_length, embedding_length]; with global
        # pooling we average over the sequence dimension (dim=1)
        self.dim = dim

    def forward(self, x, **kwargs):
        return x.mean(dim=self.dim)


class CLSPooling(nn.Module):
    """helper class for CLS token extraction"""

    def __init__(self, cls_position=0):
        super().__init__()
        # the CLS token currently sits at the first position of the sequence,
        # but may move to the last position in the future
        self.cls_position = cls_position

    def forward(self, x, **kwargs):
        # input is [batch_size, sequence_len, embedding_len], so the sequence
        # dimension is dim 1; select the CLS slot along it
        return x[:, self.cls_position, :]
class TransformerEncoderWrapper(nn.TransformerEncoder):
    """wrapper around PyTorch's TransformerEncoder that re-initializes layer parameters,
    so each transformer encoder layer has a different initialization"""

    # todo: PyTorch is changing its transformer API... check up on and see if there is a better way
    def __init__(self, encoder_layer, num_layers, norm=None, reset_params=True):
        super().__init__(encoder_layer, num_layers, norm)
        if reset_params:
            # nn.TransformerEncoder deep-copies encoder_layer, so without this
            # re-init every stacked layer would start from identical weights
            self.apply(reset_parameters_helper)


class AttnModel(nn.Module):
    """Transformer-encoder network for amino acid sequences.

    Pipeline: token embedding -> (optional) positional encoding -> transformer
    encoder stack -> pooling to a fixed-length representation (global average,
    CLS token, or flatten) -> prediction head (task-specific layers, a hidden
    FC layer + linear, or a single linear layer).
    """

    # https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register this model's hyperparameters on an ArgumentParser and return it."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        parser.add_argument(
            "--pos_encoding",
            type=str,
            default="absolute",
            choices=["none", "absolute", "relative", "relative_3D"],
            help="what type of positional encoding to use",
        )
        parser.add_argument(
            "--pos_encoding_dropout",
            type=float,
            default=0.1,
            # fixed typo in help text: "out much" -> "how much"
            help="how much dropout to use in positional encoding, for pos_encoding==absolute",
        )
        parser.add_argument(
            "--clipping_threshold",
            type=int,
            default=3,
            help="clipping threshold for relative position embedding, for relative and relative_3D",
        )
        parser.add_argument(
            "--contact_threshold",
            type=int,
            default=7,
            help="threshold, in angstroms, for contact map, for relative_3D",
        )
        parser.add_argument("--embedding_len", type=int, default=128)
        parser.add_argument("--num_heads", type=int, default=2)
        parser.add_argument("--num_hidden", type=int, default=64)
        parser.add_argument("--num_enc_layers", type=int, default=2)
        parser.add_argument("--enc_layer_dropout", type=float, default=0.1)
        parser.add_argument(
            "--use_final_encoder_norm", action="store_true", default=False
        )

        parser.add_argument(
            "--global_average_pooling", action="store_true", default=False
        )
        parser.add_argument("--cls_pooling", action="store_true", default=False)

        parser.add_argument(
            "--use_task_specific_layers",
            action="store_true",
            default=False,
            help="exclusive with use_final_hidden_layer; takes priority over use_final_hidden_layer"
            " if both flags are set",
        )
        parser.add_argument("--task_specific_hidden_nodes", type=int, default=64)
        parser.add_argument(
            "--use_final_hidden_layer", action="store_true", default=False
        )
        parser.add_argument("--final_hidden_size", type=int, default=64)
        parser.add_argument(
            "--use_final_hidden_layer_norm", action="store_true", default=False
        )
        parser.add_argument(
            "--final_hidden_layer_norm_before_activation",
            action="store_true",
            default=False,
        )
        parser.add_argument(
            "--use_final_hidden_layer_dropout", action="store_true", default=False
        )
        parser.add_argument(
            "--final_hidden_layer_dropout_rate", type=float, default=0.2
        )

        parser.add_argument(
            "--activation",
            type=str,
            default="relu",
            help="activation function used for all activations in the network",
        )
        return parser

    def __init__(
        self,
        # data args
        num_tasks: int,
        aa_seq_len: int,
        num_tokens: int,
        # transformer encoder model args
        pos_encoding: str = "absolute",
        pos_encoding_dropout: float = 0.1,
        clipping_threshold: int = 3,
        contact_threshold: int = 7,
        # annotation fix: default is None, so this is Optional
        pdb_fns: Optional[List[str]] = None,
        embedding_len: int = 64,
        num_heads: int = 2,
        num_hidden: int = 64,
        num_enc_layers: int = 2,
        enc_layer_dropout: float = 0.1,
        use_final_encoder_norm: bool = False,
        # pooling to fixed-length representation
        global_average_pooling: bool = True,
        cls_pooling: bool = False,
        # prediction layers
        use_task_specific_layers: bool = False,
        task_specific_hidden_nodes: int = 64,
        use_final_hidden_layer: bool = False,
        final_hidden_size: int = 64,
        use_final_hidden_layer_norm: bool = False,
        final_hidden_layer_norm_before_activation: bool = False,
        use_final_hidden_layer_dropout: bool = False,
        final_hidden_layer_dropout_rate: float = 0.2,
        # activation function
        activation: str = "relu",
        *args,
        **kwargs,
    ):
        super().__init__()

        # store embedding length for use in the forward function
        self.embedding_len = embedding_len
        self.aa_seq_len = aa_seq_len

        # build up layers
        layers = collections.OrderedDict()

        # amino acid embedding
        layers["embedder"] = ScaledEmbedding(
            num_embeddings=num_tokens, embedding_dim=embedding_len, scale=True
        )

        # absolute positional encoding
        if pos_encoding == "absolute":
            layers["pos_encoder"] = PositionalEncoding(
                embedding_len, dropout=pos_encoding_dropout, max_len=512
            )

        # transformer encoder layer for none or absolute positional encoding
        if pos_encoding in ["none", "absolute"]:
            encoder_layer = torch.nn.TransformerEncoderLayer(
                d_model=embedding_len,
                nhead=num_heads,
                dim_feedforward=num_hidden,
                dropout=enc_layer_dropout,
                activation=get_activation_fn(activation),
                norm_first=True,
                batch_first=True,
            )

            # layer norm that is used after the transformer encoder layers
            # if the norm_first is False, this is *redundant* and not needed
            # but if norm_first is True, this can be used to normalize outputs from
            # the transformer encoder before inputting to the final fully connected layer
            encoder_norm = None
            if use_final_encoder_norm:
                encoder_norm = nn.LayerNorm(embedding_len)

            layers["tr_encoder"] = TransformerEncoderWrapper(
                encoder_layer=encoder_layer,
                num_layers=num_enc_layers,
                norm=encoder_norm,
            )

        # transformer encoder layer for relative position encoding
        elif pos_encoding in ["relative", "relative_3D"]:
            relative_encoder_layer = RelativeTransformerEncoderLayer(
                d_model=embedding_len,
                nhead=num_heads,
                pos_encoding=pos_encoding,
                clipping_threshold=clipping_threshold,
                contact_threshold=contact_threshold,
                pdb_fns=pdb_fns,
                dim_feedforward=num_hidden,
                dropout=enc_layer_dropout,
                activation=get_activation_fn(activation),
                norm_first=True,
            )

            encoder_norm = None
            if use_final_encoder_norm:
                encoder_norm = nn.LayerNorm(embedding_len)

            layers["tr_encoder"] = RelativeTransformerEncoder(
                encoder_layer=relative_encoder_layer,
                num_layers=num_enc_layers,
                norm=encoder_norm,
            )

        # GLOBAL AVERAGE POOLING OR CLS TOKEN
        # set up the layers and output shapes (i.e. input shapes for the pred layer)
        if global_average_pooling:
            # pool over the sequence dimension
            layers["avg_pooling"] = GlobalAveragePooling(dim=1)
            pred_layer_input_features = embedding_len
        elif cls_pooling:
            layers["cls_pooling"] = CLSPooling(cls_position=0)
            pred_layer_input_features = embedding_len
        else:
            # no global average pooling or CLS token
            # sequence dimension is still there, just flattened
            layers["flatten"] = nn.Flatten()
            pred_layer_input_features = embedding_len * aa_seq_len

        # PREDICTION
        if use_task_specific_layers:
            # task specific prediction layers (nonlinear transform for each task)
            layers["prediction"] = TaskSpecificPredictionLayers(
                num_tasks=num_tasks,
                in_features=pred_layer_input_features,
                num_hidden_nodes=task_specific_hidden_nodes,
                activation=activation,
            )
        elif use_final_hidden_layer:
            # shared hidden layer followed by a linear transform for each task
            layers["fc1"] = FCBlock(
                in_features=pred_layer_input_features,
                num_hidden_nodes=final_hidden_size,
                use_batchnorm=False,
                use_layernorm=use_final_hidden_layer_norm,
                norm_before_activation=final_hidden_layer_norm_before_activation,
                use_dropout=use_final_hidden_layer_dropout,
                dropout_rate=final_hidden_layer_dropout_rate,
                activation=activation,
            )
            layers["prediction"] = nn.Linear(
                in_features=final_hidden_size, out_features=num_tasks
            )
        else:
            layers["prediction"] = nn.Linear(
                in_features=pred_layer_input_features, out_features=num_tasks
            )

        # FINAL MODEL
        self.model = SequentialWithArgs(layers)

    def forward(self, x, **kwargs):
        return self.model(x, **kwargs)
use when grabbing the embeddings in forward() - # - on init, get a list of all PDB files we will be using - # - use a dictionary to store PDB files --> bucket_mtxs - # - forward() gets a new arg: the pdb file, which indexes into the dictionary to grab the right bucket_mtx - def __init__( - self, - embedding_len: int, - contact_threshold: int, - clipping_threshold: int, - pdb_fns: Optional[Union[str, list, tuple]] = None, - default_pdb_dir: str = "data/pdb_files", - ): +class Transpose(nn.Module): + """helper layer to swap data from (batch, seq, channels) to (batch, channels, seq) + used as a helper in the convolutional network which pytorch defaults to channels-first + """ - # preferably, pdb_fns contains full paths to the PDBs, but if just the PDB filename is given - # then it defaults to the path data/pdb_files/ + def __init__(self, dims: Tuple[int, ...] = (1, 2)): super().__init__() - self.embedding_len = embedding_len - self.clipping_threshold = clipping_threshold - self.contact_threshold = contact_threshold - self.default_pdb_dir = default_pdb_dir + self.dims = dims - # dummy buffer for getting correct device for on-the-fly bucket matrix generation - self.register_buffer("dummy_buffer", torch.empty(0), persistent=False) + def forward(self, x, **kwargs): + x = x.transpose(*self.dims).contiguous() + return x - # for 3D-based positions, the number of embeddings is generally the number of buckets - # for contact map-based distances, that is clipping_threshold + 1 - num_embeddings = clipping_threshold + 1 - # this is the embedding lookup table E_r - self.embeddings_table = nn.Embedding(num_embeddings, embedding_len) +def conv1d_out_shape(seq_len, kernel_size, stride=1, pad=0, dilation=1): + return (seq_len + (2 * pad) - (dilation * (kernel_size - 1)) - 1 // stride) + 1 - # set up pdb_fns that were passed in on init (can also be set up during runtime in forward()) - # todo: i'm using a hacky workaround to move the bucket_mtxs to the correct device - # i tried to 
make it more efficient by registering bucket matrices as buffers, but i was - # having problems with DDP syncing the buffers across processes - self.bucket_mtxs = {} - self.bucket_mtxs_device = self.dummy_buffer.device - self._init_pdbs(pdb_fns) - def forward(self, pdb_fn): - # compute matrix R by grabbing the embeddings from the embeddings lookup table - embeddings = self.embeddings_table(self._get_bucket_mtx(pdb_fn)) - return embeddings +class ConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + dilation: int = 1, + padding: str = "same", + use_batchnorm: bool = False, + use_layernorm: bool = False, + norm_before_activation: bool = False, + use_dropout: bool = False, + dropout_rate: float = 0.2, + activation: str = "relu", + ): - # def _get_bucket_mtx(self, pdb_fn): - # """ retrieve a bucket matrix given the pdb_fn. - # if the pdb_fn was provided at init or has already been computed, then the bucket matrix will be - # retrieved from the object buffer. if the bucket matrix has not been computed yet, it will be here """ - # pdb_attr = self._pdb_key(pdb_fn) - # if hasattr(self, pdb_attr): - # return getattr(self, pdb_attr) - # else: - # # encountering a new PDB at runtime... process it - # # todo: if there's a new PDB at runtime, it will be initialized separately in each instance - # # of RelativePosition3D, for each layer. It would be more efficient to have a global - # # bucket_mtx registry... perhaps in the RelativeTransformerEncoder class, that can be passed through - # self._init_pdb(pdb_fn) - # return getattr(self, pdb_attr) + super().__init__() - def _move_bucket_mtxs(self, device): - for k, v in self.bucket_mtxs.items(): - self.bucket_mtxs[k] = v.to(device) - self.bucket_mtxs_device = device + if use_batchnorm and use_layernorm: + raise ValueError( + "Only one of use_batchnorm or use_layernorm can be set to True" + ) - def _get_bucket_mtx(self, pdb_fn): - """retrieve a bucket matrix given the pdb_fn. 
- if the pdb_fn was provided at init or has already been computed, then the bucket matrix will be - retrieved from the bucket_mtxs dictionary. else, it will be computed now on-the-fly - """ + self.use_batchnorm = use_batchnorm + self.use_layernorm = use_layernorm + self.norm_before_activation = norm_before_activation + self.use_dropout = use_dropout - # ensure that all the bucket matrices are on the same device as the nn.Embedding - if self.bucket_mtxs_device != self.dummy_buffer.device: - self._move_bucket_mtxs(self.dummy_buffer.device) + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + ) - pdb_attr = self._pdb_key(pdb_fn) - if pdb_attr in self.bucket_mtxs: - return self.bucket_mtxs[pdb_attr] - else: - # encountering a new PDB at runtime... process it - # todo: if there's a new PDB at runtime, it will be initialized separately in each instance - # of RelativePosition3D, for each layer. It would be more efficient to have a global - # bucket_mtx registry... 
perhaps in the RelativeTransformerEncoder class, that can be passed through - self._init_pdb(pdb_fn) - return self.bucket_mtxs[pdb_attr] + self.activation = get_activation_fn(activation, functional=False) - # def _set_bucket_mtx(self, pdb_fn, bucket_mtx): - # """ store a bucket matrix as a buffer """ - # # if PyTorch ever implements a BufferDict, we could use it here efficiently - # # there is also BufferDict from https://botorch.org/api/_modules/botorch/utils/torch.html - # # would just need to modify it to have an option for persistent=False - # bucket_mtx = bucket_mtx.to(self.dummy_buffer.device) - # - # self.register_buffer(self._pdb_key(pdb_fn), bucket_mtx, persistent=False) + if use_batchnorm: + self.norm = nn.BatchNorm1d(out_channels) - def _set_bucket_mtx(self, pdb_fn, bucket_mtx): - """store a bucket matrix in the bucket dict""" + if use_layernorm: + self.norm = nn.LayerNorm(out_channels) - # move the bucket_mtx to the same device that the other bucket matrices are on - bucket_mtx = bucket_mtx.to(self.bucket_mtxs_device) + if use_dropout: + self.dropout = nn.Dropout(p=dropout_rate) - self.bucket_mtxs[self._pdb_key(pdb_fn)] = bucket_mtx + def forward(self, x, **kwargs): + x = self.conv(x) - @staticmethod - def _pdb_key(pdb_fn): - """return a unique key for the given pdb_fn, used to map unique PDBs""" - # note this key does NOT currently support PDBs with the same basename but different paths - # assumes every PDB is in the format .pdb - # should be a compatible with being a class attribute, as it is used as a pytorch buffer name - return f"pdb_{basename(pdb_fn).split('.')[0]}" + # norm can be before or after activation, using flag + if self.use_batchnorm and self.norm_before_activation: + x = self.norm(x) + elif self.use_layernorm and self.norm_before_activation: + x = self.norm(x.transpose(1, 2)).transpose(1, 2) - def _init_pdbs(self, pdb_fns): - start = time.time() + x = self.activation(x) - if pdb_fns is None: - # nothing to initialize if pdb_fns is None 
- return + # batchnorm being applied after activation, there is some discussion on this online + if self.use_batchnorm and not self.norm_before_activation: + x = self.norm(x) + elif self.use_layernorm and not self.norm_before_activation: + x = self.norm(x.transpose(1, 2)).transpose(1, 2) - # make sure pdb_fns is a list - if not isinstance(pdb_fns, list) and not isinstance(pdb_fns, tuple): - pdb_fns = [pdb_fns] + # dropout being applied after batchnorm, there is some discussion on this online + if self.use_dropout: + x = self.dropout(x) - # init each pdb fn in the list - for pdb_fn in pdb_fns: - self._init_pdb(pdb_fn) + return x - print("Initialized PDB bucket matrices in: {:.3f}".format(time.time() - start)) - def _init_pdb(self, pdb_fn): - """process a pdb file for use with structure-based relative attention""" - # if pdb_fn is not a full path, default to the path data/pdb_files/ - if dirname(pdb_fn) == "": - # handle the case where the pdb file is in the current working directory - # if there is a PDB file in the cwd.... then just use it as is. otherwise, append the default. 
- if not isfile(pdb_fn): - pdb_fn = join(self.default_pdb_dir, pdb_fn) +class ConvModel2(nn.Module): + """convolutional source model that supports padded inputs, pooling, etc""" - # create a structure graph from the pdb_fn and contact threshold - cbeta_mtx = cbeta_distance_matrix(pdb_fn) - structure_graph = dist_thresh_graph(cbeta_mtx, self.contact_threshold) + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument("--use_embedding", action="store_true", default=False) + parser.add_argument("--embedding_len", type=int, default=128) - # bucket_mtx indexes into the embedding lookup table to create the final distance matrix - bucket_mtx = self._compute_bucket_mtx(structure_graph) + parser.add_argument("--num_conv_layers", type=int, default=1) + parser.add_argument("--kernel_sizes", type=int, nargs="+", default=[7]) + parser.add_argument("--out_channels", type=int, nargs="+", default=[128]) + parser.add_argument("--dilations", type=int, nargs="+", default=[1]) + parser.add_argument( + "--padding", type=str, default="valid", choices=["valid", "same"] + ) + parser.add_argument("--use_conv_layer_norm", action="store_true", default=False) + parser.add_argument( + "--conv_layer_norm_before_activation", action="store_true", default=False + ) + parser.add_argument( + "--use_conv_layer_dropout", action="store_true", default=False + ) + parser.add_argument("--conv_layer_dropout_rate", type=float, default=0.2) - self._set_bucket_mtx(pdb_fn, bucket_mtx) + parser.add_argument( + "--global_average_pooling", action="store_true", default=False + ) - def _compute_bucketed_neighbors(self, structure_graph, source_node): - """gets the bucketed neighbors from the given source node and structure graph""" - if self.clipping_threshold < 0: - raise ValueError("Clipping threshold must be >= 0") + parser.add_argument( + "--use_task_specific_layers", action="store_true", default=False + ) + 
parser.add_argument("--task_specific_hidden_nodes", type=int, default=64) + parser.add_argument( + "--use_final_hidden_layer", action="store_true", default=False + ) + parser.add_argument("--final_hidden_size", type=int, default=64) + parser.add_argument( + "--use_final_hidden_layer_norm", action="store_true", default=False + ) + parser.add_argument( + "--final_hidden_layer_norm_before_activation", + action="store_true", + default=False, + ) + parser.add_argument( + "--use_final_hidden_layer_dropout", action="store_true", default=False + ) + parser.add_argument( + "--final_hidden_layer_dropout_rate", type=float, default=0.2 + ) - sspl = _inv_dict( - nx.single_source_shortest_path_length(structure_graph, source_node) + parser.add_argument( + "--activation", + type=str, + default="relu", + help="activation function used for all activations in the network", ) - if self.clipping_threshold is not None: - num_buckets = 1 + self.clipping_threshold - sspl = _combine_d(sspl, self.clipping_threshold, num_buckets - 1) + return parser - return sspl + def __init__( + self, + # data + num_tasks: int, + aa_seq_len: int, + aa_encoding_len: int, + num_tokens: int, + # convolutional model args + use_embedding: bool = False, + embedding_len: int = 64, + num_conv_layers: int = 1, + kernel_sizes: List[int] = (7,), + out_channels: List[int] = (128,), + dilations: List[int] = (1,), + padding: str = "valid", + use_conv_layer_norm: bool = False, + conv_layer_norm_before_activation: bool = False, + use_conv_layer_dropout: bool = False, + conv_layer_dropout_rate: float = 0.2, + # pooling + global_average_pooling: bool = True, + # prediction layers + use_task_specific_layers: bool = False, + task_specific_hidden_nodes: int = 64, + use_final_hidden_layer: bool = False, + final_hidden_size: int = 64, + use_final_hidden_layer_norm: bool = False, + final_hidden_layer_norm_before_activation: bool = False, + use_final_hidden_layer_dropout: bool = False, + final_hidden_layer_dropout_rate: float = 
0.2, + # activation function + activation: str = "relu", + *args, + **kwargs, + ): - def _compute_bucket_mtx(self, structure_graph): - """get the bucket_mtx for the given structure_graph - calls _get_bucketed_neighbors for every node in the structure_graph""" - num_residues = len(list(structure_graph)) + super(ConvModel2, self).__init__() - # index into the embedding lookup table to create the final distance matrix - bucket_mtx = torch.zeros(num_residues, num_residues, dtype=torch.long) + # build up the layers + layers = collections.OrderedDict() - for node_num in sorted(list(structure_graph)): - bucketed_neighbors = self._compute_bucketed_neighbors( - structure_graph, node_num + # amino acid embedding + if use_embedding: + layers["embedder"] = ScaledEmbedding( + num_embeddings=num_tokens, embedding_dim=embedding_len, scale=False ) - for bucket_num, neighbors in bucketed_neighbors.items(): - bucket_mtx[node_num, neighbors] = bucket_num + # transpose the input to match PyTorch's expected format + layers["transpose"] = Transpose(dims=(1, 2)) - return bucket_mtx + # build up the convolutional layers + for layer_num in range(num_conv_layers): + # determine the number of input channels for the first convolutional layer + if layer_num == 0 and use_embedding: + # for the first convolutional layer, the in_channels is the embedding_len + in_channels = embedding_len + elif layer_num == 0 and not use_embedding: + # for the first convolutional layer, the in_channels is the aa_encoding_len + in_channels = aa_encoding_len + else: + in_channels = out_channels[layer_num - 1] + layers[f"conv{layer_num}"] = ConvBlock( + in_channels=in_channels, + out_channels=out_channels[layer_num], + kernel_size=kernel_sizes[layer_num], + dilation=dilations[layer_num], + padding=padding, + use_batchnorm=False, + use_layernorm=use_conv_layer_norm, + norm_before_activation=conv_layer_norm_before_activation, + use_dropout=use_conv_layer_dropout, + dropout_rate=conv_layer_dropout_rate, + 
activation=activation, + ) -class RelativePosition(nn.Module): - """creates the embedding lookup table E_r and computes R - note this inherits from pl.LightningModule instead of nn.Module - makes it easier to access the device with `self.device` - might be able to keep it as an nn.Module using the hacky dummy_param or commented out .device property - """ + # handle transition from convolutional layers to fully connected layer + # either use global average pooling or flatten + # take into consideration whether we are using valid or same padding + if global_average_pooling: + # global average pooling (mean across the seq len dimension) + # the seq len dimensions is the last dimension (batch_size, num_filters, seq_len) + layers["avg_pooling"] = GlobalAveragePooling(dim=-1) + # the prediction layers will take num_filters input features + pred_layer_input_features = out_channels[-1] - def __init__(self, embedding_len: int, clipping_threshold: int): - """ - embedding_len: the length of the embedding, may be d_model, or d_model // num_heads for multihead - clipping_threshold: the maximum relative position, referred to as k by Shaw et al. - """ - super().__init__() - self.embedding_len = embedding_len - self.clipping_threshold = clipping_threshold - # for sequence-based distances, the number of embeddings is 2*k+1, where k is the clipping threshold - num_embeddings = 2 * clipping_threshold + 1 + else: + # no global average pooling. flatten instead. 
+ layers["flatten"] = nn.Flatten() + # calculate the final output len of the convolutional layers + # and the number of input features for the prediction layers + if padding == "valid": + # valid padding (aka no padding) results in shrinking length in progressive layers + conv_out_len = conv1d_out_shape( + aa_seq_len, kernel_size=kernel_sizes[0], dilation=dilations[0] + ) + for layer_num in range(1, num_conv_layers): + conv_out_len = conv1d_out_shape( + conv_out_len, + kernel_size=kernel_sizes[layer_num], + dilation=dilations[layer_num], + ) + pred_layer_input_features = conv_out_len * out_channels[-1] + else: + # padding == "same" + pred_layer_input_features = aa_seq_len * out_channels[-1] - # this is the embedding lookup table E_r - self.embeddings_table = nn.Embedding(num_embeddings, embedding_len) + # prediction layer + if use_task_specific_layers: + layers["prediction"] = TaskSpecificPredictionLayers( + num_tasks=num_tasks, + in_features=pred_layer_input_features, + num_hidden_nodes=task_specific_hidden_nodes, + activation=activation, + ) - # for getting the correct device for range vectors in forward - self.register_buffer("dummy_buffer", torch.empty(0), persistent=False) + # final hidden layer (with potential additional dropout) + elif use_final_hidden_layer: + layers["fc1"] = FCBlock( + in_features=pred_layer_input_features, + num_hidden_nodes=final_hidden_size, + use_batchnorm=False, + use_layernorm=use_final_hidden_layer_norm, + norm_before_activation=final_hidden_layer_norm_before_activation, + use_dropout=use_final_hidden_layer_dropout, + dropout_rate=final_hidden_layer_dropout_rate, + activation=activation, + ) + layers["prediction"] = nn.Linear( + in_features=final_hidden_size, out_features=num_tasks + ) - def forward(self, length_q, length_k): - # supports different length sequences, but in self-attention length_q and length_k are the same - range_vec_q = torch.arange(length_q, device=self.dummy_buffer.device) - range_vec_k = torch.arange(length_k, 
device=self.dummy_buffer.device) + else: + layers["prediction"] = nn.Linear( + in_features=pred_layer_input_features, out_features=num_tasks + ) - # this sets up the standard sequence-based distance matrix for relative positions - # the current position is 0, positions to the right are +1, +2, etc, and to the left -1, -2, etc - distance_mat = range_vec_k[None, :] - range_vec_q[:, None] - distance_mat_clipped = torch.clamp( - distance_mat, -self.clipping_threshold, self.clipping_threshold - ) + self.model = nn.Sequential(layers) - # convert to indices, indexing into the embedding table - final_mat = (distance_mat_clipped + self.clipping_threshold).long() + def forward(self, x, **kwargs): + output = self.model(x) + return output - # compute matrix R by grabbing the embeddings from the embedding lookup table - embeddings = self.embeddings_table(final_mat) - return embeddings +class ConvModel(nn.Module): + """a convolutional network with convolutional layers followed by a fully connected layer""" + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument("--num_conv_layers", type=int, default=1) + parser.add_argument("--kernel_sizes", type=int, nargs="+", default=[7]) + parser.add_argument("--out_channels", type=int, nargs="+", default=[128]) + parser.add_argument( + "--padding", type=str, default="valid", choices=["valid", "same"] + ) + parser.add_argument( + "--use_final_hidden_layer", + action="store_true", + help="whether to use a final hidden layer", + ) + parser.add_argument( + "--final_hidden_size", + type=int, + default=128, + help="number of nodes in the final hidden layer", + ) + parser.add_argument( + "--use_dropout", + action="store_true", + help="whether to use dropout in the final hidden layer", + ) + parser.add_argument( + "--dropout_rate", + type=float, + default=0.2, + help="dropout rate in the final hidden layer", + ) + parser.add_argument( + 
"--use_task_specific_layers", action="store_true", default=False + ) + parser.add_argument("--task_specific_hidden_nodes", type=int, default=64) + return parser -class RelativeMultiHeadAttention(nn.Module): def __init__( self, - embed_dim, - num_heads, - dropout, - pos_encoding, - clipping_threshold, - contact_threshold, - pdb_fns, + num_tasks: int, + aa_seq_len: int, + aa_encoding_len: int, + num_conv_layers: int = 1, + kernel_sizes: List[int] = (7,), + out_channels: List[int] = (128,), + padding: str = "valid", + use_final_hidden_layer: bool = True, + final_hidden_size: int = 128, + use_dropout: bool = False, + dropout_rate: float = 0.2, + use_task_specific_layers: bool = False, + task_specific_hidden_nodes: int = 64, + *args, + **kwargs, ): - """ - Multi-head attention with relative position embeddings. Input data should be in batch_first format. - :param embed_dim: aka d_model, aka hid_dim - :param num_heads: number of heads - :param dropout: how much dropout for scaled dot product attention - - :param pos_encoding: what type of positional encoding to use, relative or relative3D - :param clipping_threshold: clipping threshold for relative position embedding - :param contact_threshold: for relative_3D, the threshold in angstroms for the contact map - :param pdb_fns: pdb file(s) to set up the relative position object - - """ - super().__init__() - assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" - - # model dimensions - self.embed_dim = embed_dim - self.num_heads = num_heads - self.head_dim = embed_dim // num_heads - - # pos encoding stuff - self.pos_encoding = pos_encoding - self.clipping_threshold = clipping_threshold - self.contact_threshold = contact_threshold - if pdb_fns is not None and not isinstance(pdb_fns, list): - pdb_fns = [pdb_fns] - self.pdb_fns = pdb_fns - - # relative position embeddings for use with keys and values - # Shaw et al. uses relative position information for both keys and values - # Huang et al. 
only uses it for the keys, which is probably enough - if pos_encoding == "relative": - self.relative_position_k = RelativePosition( - self.head_dim, self.clipping_threshold - ) - self.relative_position_v = RelativePosition( - self.head_dim, self.clipping_threshold - ) - elif pos_encoding == "relative_3D": - self.relative_position_k = RelativePosition3D( - self.head_dim, - self.contact_threshold, - self.clipping_threshold, - self.pdb_fns, + super(ConvModel, self).__init__() + + # set up the model as a Sequential block (less to do in forward()) + layers = collections.OrderedDict() + + layers["transpose"] = Transpose(dims=(1, 2)) + + for layer_num in range(num_conv_layers): + # for the first convolutional layer, the in_channels is the feature_len + in_channels = ( + aa_encoding_len if layer_num == 0 else out_channels[layer_num - 1] ) - self.relative_position_v = RelativePosition3D( - self.head_dim, - self.contact_threshold, - self.clipping_threshold, - self.pdb_fns, + + layers["conv{}".format(layer_num)] = nn.Sequential( + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels[layer_num], + kernel_size=kernel_sizes[layer_num], + padding=padding, + ), + nn.ReLU(), ) - else: - raise ValueError("unrecognized pos_encoding: {}".format(pos_encoding)) - # WQ, WK, and WV from attention is all you need - # note these default to bias=True, same as PyTorch implementation - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) + layers["flatten"] = nn.Flatten() - # WO from attention is all you need - # used for the final projection when computing multi-head attention - # PyTorch uses NonDynamicallyQuantizableLinear instead of Linear to avoid triggering an obscure - # error quantizing the model https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L122 - # todo: if quantizing the model, explore if the above is a concern for us - self.out_proj = 
nn.Linear(embed_dim, embed_dim) + # calculate the final output len of the convolutional layers + # and the number of input features for the prediction layers + if padding == "valid": + # valid padding (aka no padding) results in shrinking length in progressive layers + conv_out_len = conv1d_out_shape(aa_seq_len, kernel_size=kernel_sizes[0]) + for layer_num in range(1, num_conv_layers): + conv_out_len = conv1d_out_shape( + conv_out_len, kernel_size=kernel_sizes[layer_num] + ) + next_dim = conv_out_len * out_channels[-1] + elif padding == "same": + next_dim = aa_seq_len * out_channels[-1] + else: + raise ValueError("unexpected value for padding: {}".format(padding)) - # dropout for scaled dot product attention - self.dropout = nn.Dropout(dropout) + # final hidden layer (with potential additional dropout) + if use_final_hidden_layer: + layers["fc1"] = FCBlock( + in_features=next_dim, + num_hidden_nodes=final_hidden_size, + use_batchnorm=False, + use_dropout=use_dropout, + dropout_rate=dropout_rate, + ) + next_dim = final_hidden_size - # scaling factor for scaled dot product attention - scale = torch.sqrt(torch.FloatTensor([self.head_dim])) - # persistent=False if you don't want to save it inside state_dict - self.register_buffer("scale", scale) + # final prediction layer + # either task specific nonlinear layers or a single linear layer + if use_task_specific_layers: + layers["prediction"] = TaskSpecificPredictionLayers( + num_tasks=num_tasks, + in_features=next_dim, + num_hidden_nodes=task_specific_hidden_nodes, + ) + else: + layers["prediction"] = nn.Linear( + in_features=next_dim, out_features=num_tasks + ) - # toggles meant to be set directly by user - self.need_weights = False - self.average_attn_weights = True + self.model = nn.Sequential(layers) - def _compute_attn_weights(self, query, key, len_q, len_k, batch_size, mask, pdb_fn): - """computes the attention weights (a "compatability function" of queries with corresponding keys)""" + def forward(self, x, 
**kwargs): + output = self.model(x) + return output - # calculate the first term in the numerator attn1, which is Q*K - # todo: pytorch reshapes q,k and v to 3 dimensions (similar to how r_q2 is below) - # is that functionally equivalent to what we're doing? is their way faster? - # r_q1 = [batch_size, num_heads, len_q, head_dim] - r_q1 = query.view(batch_size, len_q, self.num_heads, self.head_dim).permute( - 0, 2, 1, 3 - ) - # todo: we could directly permute r_k1 to [batch_size, num_heads, head_dim, len_k] - # to make it compatible for matrix multiplication with r_q1, instead of 2-step approach - # r_k1 = [batch_size, num_heads, len_k, head_dim] - r_k1 = key.view(batch_size, len_k, self.num_heads, self.head_dim).permute( - 0, 2, 1, 3 - ) - # attn1 = [batch_size, num_heads, len_q, len_k] - attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2)) - # calculate the second term in the numerator attn2, which is Q*R - # r_q2 = [query_len, batch_size * num_heads, head_dim] - r_q2 = ( - query.permute(1, 0, 2) - .contiguous() - .view(len_q, batch_size * self.num_heads, self.head_dim) +class FCModel(nn.Module): + + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument("--num_layers", type=int, default=1) + parser.add_argument("--num_hidden", nargs="+", type=int, default=[128]) + parser.add_argument("--use_batchnorm", action="store_true", default=False) + parser.add_argument("--use_layernorm", action="store_true", default=False) + parser.add_argument( + "--norm_before_activation", action="store_true", default=False ) + parser.add_argument("--use_dropout", action="store_true", default=False) + parser.add_argument("--dropout_rate", type=float, default=0.2) + return parser - # todo: support multiple different PDB base structures per batch - # one option: - # - require batches to be all the same protein - # - add argument to forward() to accept the PDB file for the protein in the batch - # 
- then we just pass in the PDB file to relative position's forward() - # to support multiple different structures per batch: - # - add argument to forward() to accept PDB files, one for each item in batch - # - make corresponding changing in relative_position object to return R for each structure - # - note: if there are a lot of of different structures, and the sequence lengths are long, - # this could be memory prohibitive because R (rel_pos_k) can take up a lot of mem for long seqs - # - adjust the attn2 calculation to factor in the multiple different R matrices. - # the way to do this might have to be to do multiple matmuls, one for each each - # basically, would split up r_q2 into several matrices grouped by structure, and then - # multiply with corresponding R, then combine back into the exact same order of the original r_q2 - # note: this may be computationally intensive (splitting, more matrix muliplies, joining) - # another option would be to create views(?), repeating the different Rs so we can do a - # a matris multiply directly with r_q2 - # - would shapes be affected if there was padding in the queries, keys, values? 
+ def __init__( + self, + num_tasks: int, + seq_encoding_len: int, + num_layers: int = 1, + num_hidden: List[int] = (128,), + use_batchnorm: bool = False, + use_layernorm: bool = False, + norm_before_activation: bool = False, + use_dropout: bool = False, + dropout_rate: float = 0.2, + activation: str = "relu", + *args, + **kwargs, + ): + super().__init__() - if self.pos_encoding == "relative": - # rel_pos_k = [len_q, len_k, head_dim] - rel_pos_k = self.relative_position_k(len_q, len_k) - elif self.pos_encoding == "relative_3D": - # rel_pos_k = [sequence length (from PDB structure), head_dim] - rel_pos_k = self.relative_position_k(pdb_fn) - else: - raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding)) + # set up the model as a Sequential block (less to do in forward()) + layers = collections.OrderedDict() - # the matmul basically computes the dot product between each input position’s query vector and - # its corresponding relative position embeddings across all input sequences in the heads and batch - # attn2 = [batch_size * num_heads, len_q, len_k] - attn2 = torch.matmul(r_q2, rel_pos_k.transpose(1, 2)).transpose(0, 1) - # attn2 = [batch_size, num_heads, len_q, len_k] - attn2 = attn2.contiguous().view(batch_size, self.num_heads, len_q, len_k) + # flatten inputs as this is all fully connected + layers["flatten"] = nn.Flatten() - # calculate attention weights - attn_weights = (attn1 + attn2) / self.scale + # build up the variable number of hidden layers (fully connected + ReLU + dropout (if set)) + for layer_num in range(num_layers): + # for the first layer (layer_num == 0), in_features is determined by given input + # for subsequent layers, the in_features is the previous layer's num_hidden + in_features = ( + seq_encoding_len if layer_num == 0 else num_hidden[layer_num - 1] + ) - # apply mask if given - if mask is not None: - # todo: pytorch uses float("-inf") instead of -1e10 - attn_weights = attn_weights.masked_fill(mask == 0, -1e10) + 
layers["fc{}".format(layer_num)] = FCBlock( + in_features=in_features, + num_hidden_nodes=num_hidden[layer_num], + use_batchnorm=use_batchnorm, + use_layernorm=use_layernorm, + norm_before_activation=norm_before_activation, + use_dropout=use_dropout, + dropout_rate=dropout_rate, + activation=activation, + ) - # softmax gives us attn_weights weights - attn_weights = torch.softmax(attn_weights, dim=-1) - # attn_weights = [batch_size, num_heads, len_q, len_k] - attn_weights = self.dropout(attn_weights) + # finally, the linear output layer + in_features = num_hidden[-1] if num_layers > 0 else seq_encoding_len + layers["output"] = nn.Linear(in_features=in_features, out_features=num_tasks) - return attn_weights + self.model = nn.Sequential(layers) - def _compute_avg_val( - self, value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn - ): - # todo: add option to not factor in relative position embeddings in value calculation - # calculate the first term, the attn*values - # r_v1 = [batch_size, num_heads, len_v, head_dim] - r_v1 = value.view(batch_size, len_v, self.num_heads, self.head_dim).permute( - 0, 2, 1, 3 - ) - # avg1 = [batch_size, num_heads, len_q, head_dim] - avg1 = torch.matmul(attn_weights, r_v1) + def forward(self, x, **kwargs): + output = self.model(x) + return output - # calculate the second term, the attn*R - # similar to how relative embeddings are factored in the attention weights calculation - if self.pos_encoding == "relative": - # rel_pos_v = [query_len, value_len, head_dim] - rel_pos_v = self.relative_position_v(len_q, len_v) - elif self.pos_encoding == "relative_3D": - # rel_pos_v = [sequence length (from PDB structure), head_dim] - rel_pos_v = self.relative_position_v(pdb_fn) - else: - raise ValueError("unrecognized pos_encoding: {}".format(self.pos_encoding)) - # r_attn_weights = [len_q, batch_size * num_heads, len_v] - r_attn_weights = ( - attn_weights.permute(2, 0, 1, 3) - .contiguous() - .view(len_q, batch_size * self.num_heads, len_k) - ) 
- avg2 = torch.matmul(r_attn_weights, rel_pos_v) - # avg2 = [batch_size, num_heads, len_q, head_dim] - avg2 = ( - avg2.transpose(0, 1) - .contiguous() - .view(batch_size, self.num_heads, len_q, self.head_dim) - ) +class LRModel(nn.Module): + """a simple linear model""" - # calculate avg value - x = avg1 + avg2 # [batch_size, num_heads, len_q, head_dim] - x = x.permute( - 0, 2, 1, 3 - ).contiguous() # [batch_size, len_q, num_heads, head_dim] - # x = [batch_size, len_q, embed_dim] - x = x.view(batch_size, len_q, self.embed_dim) + def __init__(self, num_tasks, seq_encoding_len, *args, **kwargs): + super().__init__() - return x + self.model = nn.Sequential( + nn.Flatten(), nn.Linear(seq_encoding_len, out_features=num_tasks) + ) - def forward(self, query, key, value, pdb_fn=None, mask=None): - # query = [batch_size, q_len, embed_dim] - # key = [batch_size, k_len, embed_dim] - # value = [batch_size, v_en, embed_dim] - batch_size = query.shape[0] - len_k, len_q, len_v = (key.shape[1], query.shape[1], value.shape[1]) + def forward(self, x, **kwargs): + output = self.model(x) + return output - # in projection (multiply inputs by WQ, WK, WV) - query = self.q_proj(query) - key = self.k_proj(key) - value = self.v_proj(value) - # first compute the attention weights, then multiply with values - # attn = [batch size, num_heads, len_q, len_k] - attn_weights = self._compute_attn_weights( - query, key, len_q, len_k, batch_size, mask, pdb_fn - ) +class TransferModel(nn.Module): + """transfer learning model""" - # take weighted average of values (weighted by attention weights) - attn_output = self._compute_avg_val( - value, len_q, len_k, len_v, attn_weights, batch_size, pdb_fn - ) + @staticmethod + def add_model_specific_args(parent_parser): - # output projection - # attn_output = [batch_size, len_q, embed_dim] - attn_output = self.out_proj(attn_output) + def none_or_int(value: str): + return None if value.lower() == "none" else int(value) - if self.need_weights: - # return attention 
weights in addition to attention - # average the weights over the heads (to get overall attention) - # attn_weights = [batch_size, len_q, len_k] - if self.average_attn_weights: - attn_weights = attn_weights.sum(dim=1) / self.num_heads - return {"attn_output": attn_output, "attn_weights": attn_weights} - else: - return attn_output + p = ArgumentParser(parents=[parent_parser], add_help=False) + # for model set up + p.add_argument("--pretrained_ckpt_path", type=str, default=None) -class RelativeTransformerEncoderLayer(nn.Module): - """ - d_model: the number of expected features in the input (required). - nhead: the number of heads in the MultiHeadAttention models (required). - clipping_threshold: the clipping threshold for relative position embeddings - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of the intermediate layer, can be a string - ("relu" or "gelu") or a unary callable. Default: relu - layer_norm_eps: the eps value in layer normalization components (default=1e-5). - norm_first: if ``True``, layer norm is done prior to attention and feedforward - operations, respectively. Otherwise, it's done after. Default: ``False`` (after). - """ + # where to cut off the backbone + p.add_argument( + "--backbone_cutoff", + type=none_or_int, + default=-1, + help="where to cut off the backbone. can be a negative int, indexing back from " + "pretrained_model.model.model. a value of -1 would chop off the backbone prediction head. " + "a value of -2 chops the prediction head and FC layer. a value of -3 chops" + "the above, as well as the global average pooling layer. all depends on architecture.", + ) - # this is some kind of torch jit compiling helper... 
will also ensure these values don't change - __constants__ = ["batch_first", "norm_first"] + p.add_argument( + "--pred_layer_input_features", + type=int, + default=None, + help="if None, number of features will be determined based on backbone_cutoff and standard " + "architecture. otherwise, specify the number of input features for the prediction layer", + ) + + # top net args + p.add_argument( + "--top_net_type", + type=str, + default="linear", + choices=["linear", "nonlinear", "sklearn"], + ) + p.add_argument("--top_net_hidden_nodes", type=int, default=256) + p.add_argument("--top_net_use_batchnorm", action="store_true") + p.add_argument("--top_net_use_dropout", action="store_true") + p.add_argument("--top_net_dropout_rate", type=float, default=0.1) + + return p def __init__( self, - d_model, - nhead, - pos_encoding="relative", - clipping_threshold=3, - contact_threshold=7, - pdb_fns=None, - dim_feedforward=2048, - dropout=0.1, - activation=F.relu, - layer_norm_eps=1e-5, - norm_first=False, - ) -> None: + # pretrained model + pretrained_ckpt_path: Optional[str] = None, + pretrained_hparams: Optional[dict] = None, + backbone_cutoff: Optional[int] = -1, + # top net + pred_layer_input_features: Optional[int] = None, + top_net_type: str = "linear", + top_net_hidden_nodes: int = 256, + top_net_use_batchnorm: bool = False, + top_net_use_dropout: bool = False, + top_net_dropout_rate: float = 0.1, + *args, + **kwargs, + ): - self.batch_first = True + super().__init__() - super(RelativeTransformerEncoderLayer, self).__init__() + # error checking: if pretrained_ckpt_path is None, then pretrained_hparams must be specified + if pretrained_ckpt_path is None and pretrained_hparams is None: + raise ValueError( + "Either pretrained_ckpt_path or pretrained_hparams must be specified" + ) - self.self_attn = RelativeMultiHeadAttention( - d_model, - nhead, - dropout, - pos_encoding, - clipping_threshold, - contact_threshold, - pdb_fns, - ) + # note: pdb_fns is loaded from transfer 
model arguments rather than original source model hparams + # if pdb_fns is specified as a kwarg, pass it on for structure-based RPE + # otherwise, can just set pdb_fns to None, and structure-based RPE will handle new PDBs on the fly + pdb_fns = kwargs["pdb_fns"] if "pdb_fns" in kwargs else None - # feed forward model - self.linear1 = Linear(d_model, dim_feedforward) - self.dropout = Dropout(dropout) - self.linear2 = Linear(dim_feedforward, d_model) + # generate a fresh backbone using pretrained_hparams if specified + # otherwise load the backbone from the pretrained checkpoint + # we prioritize pretrained_hparams over pretrained_ckpt_path because + # pretrained_hparams will only really be specified if we are loading from a DMSTask checkpoint + # meaning the TransferModel has already been fine-tuned on DMS data, and we are likely loading + # weights from that finetuning (including weights for the backbone) + # whereas if pretrained_hparams is not specified but pretrained_ckpt_path is, then we are + # likely finetuning the TransferModel for the first time, and we need the pretrained weights for the + # backbone from the RosettaTask checkpoint + if pretrained_hparams is not None: + # pretrained_hparams will only be specified if we are loading from a DMSTask checkpoint + pretrained_hparams["pdb_fns"] = pdb_fns + pretrained_model = Model[pretrained_hparams["model_name"]].cls( + **pretrained_hparams + ) + self.pretrained_hparams = pretrained_hparams + else: + # not supported in metl-pretrained + raise NotImplementedError( + "Loading pretrained weights from RosettaTask checkpoint not supported" + ) - self.norm_first = norm_first - self.norm1 = LayerNorm(d_model, eps=layer_norm_eps) - self.norm2 = LayerNorm(d_model, eps=layer_norm_eps) - self.dropout1 = Dropout(dropout) - self.dropout2 = Dropout(dropout) + layers = collections.OrderedDict() - # Legacy string support for activation function. 
- if isinstance(activation, str): - self.activation = get_activation_fn(activation) + # set the backbone to all layers except the last layer (the pre-trained prediction layer) + if backbone_cutoff is None: + layers["backbone"] = SequentialWithArgs( + *list(pretrained_model.model.children()) + ) else: - self.activation = activation + layers["backbone"] = SequentialWithArgs( + *list(pretrained_model.model.children())[0:backbone_cutoff] + ) - def forward(self, src: Tensor, pdb_fn=None) -> Tensor: - x = src - if self.norm_first: - x = x + self._sa_block(self.norm1(x), pdb_fn=pdb_fn) - x = x + self._ff_block(self.norm2(x)) - else: - x = self.norm1(x + self._sa_block(x)) - x = self.norm2(x + self._ff_block(x)) + if top_net_type == "sklearn": + # sklearn top not doesn't require any more layers, just return model for the repr layer + self.model = SequentialWithArgs(layers) + return - return x + # figure out dimensions of input into the prediction layer + if pred_layer_input_features is None: + # todo: can make this more robust by checking if the pretrained_mode.hparams for use_final_hidden_layer, + # global_average_pooling, etc. then can determine what the layer will be based on backbone_cutoff. + # currently, assumes that pretrained_model uses global average pooling and a final_hidden_layer + if backbone_cutoff is None: + # no backbone cutoff... 
use the full network (including tasks) as the backbone + pred_layer_input_features = self.pretrained_hparams["num_tasks"] + elif backbone_cutoff == -1: + pred_layer_input_features = self.pretrained_hparams["final_hidden_size"] + elif backbone_cutoff == -2: + pred_layer_input_features = self.pretrained_hparams["embedding_len"] + elif backbone_cutoff == -3: + pred_layer_input_features = ( + self.pretrained_hparams["embedding_len"] * kwargs["aa_seq_len"] + ) + else: + raise ValueError( + "can't automatically determine pred_layer_input_features for given backbone_cutoff" + ) - # self-attention block - def _sa_block(self, x: Tensor, pdb_fn=None) -> Tensor: - x = self.self_attn(x, x, x, pdb_fn=pdb_fn) - if isinstance(x, dict): - # handle the case where we are returning attention weights - x = x["attn_output"] - return self.dropout1(x) + layers["flatten"] = nn.Flatten(start_dim=1) - # feed forward block - def _ff_block(self, x: Tensor) -> Tensor: - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - return self.dropout2(x) + # create a new prediction layer on top of the backbone + if top_net_type == "linear": + # linear layer for prediction + layers["prediction"] = nn.Linear( + in_features=pred_layer_input_features, out_features=1 + ) + elif top_net_type == "nonlinear": + # fully connected with hidden layer + fc_block = FCBlock( + in_features=pred_layer_input_features, + num_hidden_nodes=top_net_hidden_nodes, + use_batchnorm=top_net_use_batchnorm, + use_dropout=top_net_use_dropout, + dropout_rate=top_net_dropout_rate, + ) + + pred_layer = nn.Linear(in_features=top_net_hidden_nodes, out_features=1) + + layers["prediction"] = SequentialWithArgs(fc_block, pred_layer) + else: + raise ValueError( + "Unexpected type of top net layer: {}".format(top_net_type) + ) + + self.model = SequentialWithArgs(layers) + + def forward(self, x, **kwargs): + return self.model(x, **kwargs) -class RelativeTransformerEncoder(nn.Module): - def __init__(self, encoder_layer, 
num_layers, norm=None, reset_params=True): - super(RelativeTransformerEncoder, self).__init__() - # using get_clones means all layers have the same initialization - # this is also a problem in PyTorch's TransformerEncoder implementation, which this is based on - # todo: PyTorch is changing its transformer API... check up on and see if there is a better way - self.layers = _get_clones(encoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm +def get_activation_fn(activation, functional=True): + if activation == "relu": + return F.relu if functional else nn.ReLU() + elif activation == "gelu": + return F.gelu if functional else nn.GELU() + elif activation == "silo" or activation == "swish": + return F.silu if functional else nn.SiLU() + elif activation == "leaky_relu" or activation == "lrelu": + return F.leaky_relu if functional else nn.LeakyReLU() + else: + raise RuntimeError("unknown activation: {}".format(activation)) - # important because get_clones means all layers have same initialization - # should recursively reset parameters for all submodules - if reset_params: - self.apply(reset_parameters_helper) - def forward(self, src: Tensor, pdb_fn=None) -> Tensor: - output = src +class Model(enum.Enum): + def __new__(cls, *args, **kwds): + value = len(cls.__members__) + 1 + obj = object.__new__(cls) + obj._value_ = value + return obj - for mod in self.layers: - output = mod(output, pdb_fn=pdb_fn) + def __init__(self, cls, transfer_model): + self.cls = cls + self.transfer_model = transfer_model - if self.norm is not None: - output = self.norm(output) + linear = LRModel, False + fully_connected = FCModel, False + cnn = ConvModel, False + cnn2 = ConvModel2, False + transformer_encoder = AttnModel, False + transfer_model = TransferModel, True - return output +def main(): + pass -def _get_clones(module, num_clones): - return nn.ModuleList([copy.deepcopy(module) for _ in range(num_clones)]) +if __name__ == "__main__": + main() -def _inv_dict(d): - 
"""helper function for contact map-based position embeddings""" - inv = dict() - for k, v in d.items(): - # collect dict keys into lists based on value - inv.setdefault(v, list()).append(k) - for k, v in inv.items(): - # put in sorted order - inv[k] = sorted(v) - return inv +UUID_URL_MAP = { + # global source models + "D72M9aEp": "https://zenodo.org/records/14908509/files/METL-G-20M-1D-D72M9aEp.pt?download=1", + "Nr9zCKpR": "https://zenodo.org/records/14908509/files/METL-G-20M-3D-Nr9zCKpR.pt?download=1", + "auKdzzwX": "https://zenodo.org/records/14908509/files/METL-G-50M-1D-auKdzzwX.pt?download=1", + "6PSAzdfv": "https://zenodo.org/records/14908509/files/METL-G-50M-3D-6PSAzdfv.pt?download=1", + # local source models + "8gMPQJy4": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-GFP-8gMPQJy4.pt?download=1", + "Hr4GNHws": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-GFP-Hr4GNHws.pt?download=1", + "8iFoiYw2": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-DLG4_2022-8iFoiYw2.pt?download=1", + "kt5DdWTa": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-DLG4_2022-kt5DdWTa.pt?download=1", + "DMfkjVzT": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-GB1-DMfkjVzT.pt?download=1", + "epegcFiH": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-GB1-epegcFiH.pt?download=1", + "kS3rUS7h": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-GRB2-kS3rUS7h.pt?download=1", + "X7w83g6S": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-GRB2-X7w83g6S.pt?download=1", + "UKebCQGz": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-Pab1-UKebCQGz.pt?download=1", + "2rr8V4th": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-Pab1-2rr8V4th.pt?download=1", + "PREhfC22": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-TEM-1-PREhfC22.pt?download=1", + "9ASvszux": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-TEM-1-9ASvszux.pt?download=1", + "HscFFkAb": 
"https://zenodo.org/records/14908509/files/METL-L-2M-1D-Ube4b-HscFFkAb.pt?download=1", + "H48oiNZN": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-Ube4b-H48oiNZN.pt?download=1", + "CEMSx7ZC": "https://zenodo.org/records/14908509/files/METL-L-2M-1D-PTEN-CEMSx7ZC.pt?download=1", + "PjxR5LW7": "https://zenodo.org/records/14908509/files/METL-L-2M-3D-PTEN-PjxR5LW7.pt?download=1", + # metl bind source models + "K6mw24Rg": "https://zenodo.org/records/14908509/files/METL-BIND-2M-3D-GB1-STANDARD-K6mw24Rg.pt?download=1", + "Bo5wn2SG": "https://zenodo.org/records/14908509/files/METL-BIND-2M-3D-GB1-BINDING-Bo5wn2SG.pt?download=1", + # finetuned models from GFP design experiment + "YoQkzoLD": "https://zenodo.org/records/14908509/files/FT-METL-L-2M-1D-GFP-YoQkzoLD.pt?download=1", + "PEkeRuxb": "https://zenodo.org/records/14908509/files/FT-METL-L-2M-3D-GFP-PEkeRuxb.pt?download=1", +} -def _combine_d(d, threshold, combined_key): - """helper function for contact map-based position embeddings - d is a dictionary with ints as keys and lists as values. 
- for all keys >= threshold, this function combines the values of those keys into a single list - """ - out_d = {} - for k, v in d.items(): - if k < threshold: - out_d[k] = v - elif k >= threshold: - if combined_key not in out_d: - out_d[combined_key] = v - else: - out_d[combined_key] += v - if combined_key in out_d: - out_d[combined_key] = sorted(out_d[combined_key]) - return out_d +IDENT_UUID_MAP = { + # the keys should be all lowercase + "metl-g-20m-1d": "D72M9aEp", + "metl-g-20m-3d": "Nr9zCKpR", + "metl-g-50m-1d": "auKdzzwX", + "metl-g-50m-3d": "6PSAzdfv", + # GFP local source models + "metl-l-2m-1d-gfp": "8gMPQJy4", + "metl-l-2m-3d-gfp": "Hr4GNHws", + # DLG4 local source models + "metl-l-2m-1d-dlg4_2022": "8iFoiYw2", + "metl-l-2m-3d-dlg4_2022": "kt5DdWTa", + # GB1 local source models + "metl-l-2m-1d-gb1": "DMfkjVzT", + "metl-l-2m-3d-gb1": "epegcFiH", + # GRB2 local source models + "metl-l-2m-1d-grb2": "kS3rUS7h", + "metl-l-2m-3d-grb2": "X7w83g6S", + # Pab1 local source models + "metl-l-2m-1d-pab1": "UKebCQGz", + "metl-l-2m-3d-pab1": "2rr8V4th", + # PTEN local source models + "metl-l-2m-1d-pten": "CEMSx7ZC", + "metl-l-2m-3d-pten": "PjxR5LW7", + # TEM-1 local source models + "metl-l-2m-1d-tem-1": "PREhfC22", + "metl-l-2m-3d-tem-1": "9ASvszux", + # Ube4b local source models + "metl-l-2m-1d-ube4b": "HscFFkAb", + "metl-l-2m-3d-ube4b": "H48oiNZN", + # METL-Bind for GB1 + "metl-bind-2m-3d-gb1-standard": "K6mw24Rg", + "metl-bind-2m-3d-gb1-binding": "Bo5wn2SG", + # GFP design models, giving them an ident + "metl-l-2m-1d-gfp-ft-design": "YoQkzoLD", + "metl-l-2m-3d-gfp-ft-design": "PEkeRuxb", +} -""" Encodes data in different formats """ +def download_checkpoint(uuid): + ckpt = torch.hub.load_state_dict_from_url( + UUID_URL_MAP[uuid], map_location="cpu", file_name=f"{uuid}.pt" + ) + state_dict = ckpt["state_dict"] + hyper_parameters = ckpt["hyper_parameters"] + return state_dict, hyper_parameters -class Encoding(Enum): - INT_SEQS = auto() - ONE_HOT = auto() +def 
_get_data_encoding(hparams): + if "encoding" in hparams and hparams["encoding"] == "int_seqs": + encoding = Encoding.INT_SEQS + elif "encoding" in hparams and hparams["encoding"] == "one_hot": + encoding = Encoding.ONE_HOT + elif ( + ("encoding" in hparams and hparams["encoding"] == "auto") + or "encoding" not in hparams + ) and hparams["model_name"] in ["transformer_encoder"]: + encoding = Encoding.INT_SEQS + else: + raise ValueError("Detected unsupported encoding in hyperparameters") -class DataEncoder: - chars = [ - "*", - "A", - "C", - "D", - "E", - "F", - "G", - "H", - "I", - "K", - "L", - "M", - "N", - "P", - "Q", - "R", - "S", - "T", - "V", - "W", - "Y", - ] - num_chars = len(chars) - mapping = {c: i for i, c in enumerate(chars)} + return encoding - def __init__(self, encoding: Encoding = Encoding.INT_SEQS): - self.encoding = encoding - def _encode_from_int_seqs(self, seq_ints): - if self.encoding == Encoding.INT_SEQS: - return seq_ints - elif self.encoding == Encoding.ONE_HOT: - one_hot = np.eye(self.num_chars)[seq_ints] - return one_hot.astype(np.float32) +def load_model_and_data_encoder(state_dict, hparams): + model = Model[hparams["model_name"]].cls(**hparams) + model.load_state_dict(state_dict) - def encode_sequences(self, char_seqs): - seq_ints = [] - for char_seq in char_seqs: - int_seq = [self.mapping[c] for c in char_seq] - seq_ints.append(int_seq) - seq_ints = np.array(seq_ints).astype(int) - return self._encode_from_int_seqs(seq_ints) + data_encoder = DataEncoder(_get_data_encoding(hparams)) - def encode_variants(self, wt, variants): - # convert wild type seq to integer encoding - wt_int = np.zeros(len(wt), dtype=np.uint8) - for i, c in enumerate(wt): - wt_int[i] = self.mapping[c] + return model, data_encoder - # tile the wild-type seq - seq_ints = np.tile(wt_int, (len(variants), 1)) - for i, variant in enumerate(variants): - # special handling if we want to encode the wild-type seq (it's already correct!) 
- if variant == "_wt": - continue +def get_from_uuid(uuid): + if uuid in UUID_URL_MAP: + state_dict, hparams = download_checkpoint(uuid) + return load_model_and_data_encoder(state_dict, hparams) + else: + raise ValueError(f"UUID {uuid} not found in UUID_URL_MAP") - # variants are a list of mutations [mutation1, mutation2, ....] - variant = variant.split(",") - for mutation in variant: - # mutations are in the form - position = int(mutation[1:-1]) - replacement = self.mapping[mutation[-1]] - seq_ints[i, position] = replacement - seq_ints = seq_ints.astype(int) - return self._encode_from_int_seqs(seq_ints) +def get_from_ident(ident): + ident = ident.lower() + if ident in IDENT_UUID_MAP: + state_dict, hparams = download_checkpoint(IDENT_UUID_MAP[ident]) + return load_model_and_data_encoder(state_dict, hparams) + else: + raise ValueError(f"Identifier {ident} not found in IDENT_UUID_MAP") + + +def get_from_checkpoint(ckpt_fn): + ckpt = torch.load(ckpt_fn, map_location="cpu") + state_dict = ckpt["state_dict"] + hyper_parameters = ckpt["hyper_parameters"] + return load_model_and_data_encoder(state_dict, hyper_parameters) class GraphType(Enum):