# File size: 5,184 Bytes
# 7ebbadf
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from .encoderblock import TransformerEncoder, get_activation_fn
class AbLang(torch.nn.Module):
    """
    AbLang inspired by ESM-2's architecture.

    Wraps an ``AbRep`` encoder and an ``AbHead`` output head. The head's
    output projection reuses the encoder's embedding matrix
    (``AbRep.aa_embed_layer.weight``).

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_embed_size: Width of the embeddings / hidden states.
        n_attn_heads: Number of attention heads passed to each encoder block.
        n_encoder_blocks: Number of encoder blocks stacked in ``AbRep``.
        padding_tkn: Token id treated as padding.
        mask_tkn: Token id of the mask token (stored on ``AbRep``).
        layer_norm_eps: Epsilon passed to the layer norms.
        a_fn: Activation function name passed to the encoder and the head.
        dropout: Dropout rate passed to ``AbRep``.
    """

    def __init__(
        self,
        vocab_size,
        hidden_embed_size,
        n_attn_heads,
        n_encoder_blocks,
        padding_tkn,
        mask_tkn,
        layer_norm_eps: float = 1e-12,
        a_fn: str = "gelu",
        dropout: float = 0.0,
    ):
        super().__init__()

        self.AbRep = AbRep(
            vocab_size,
            hidden_embed_size,
            n_attn_heads,
            n_encoder_blocks,
            padding_tkn,
            mask_tkn,
            layer_norm_eps,
            a_fn,
            dropout,
        )
        # The head projects back onto the vocabulary with the very same
        # matrix that embeds the tokens (weight sharing).
        self.AbHead = AbHead(
            vocab_size,
            hidden_embed_size,
            self.AbRep.aa_embed_layer.weight,
            layer_norm_eps,
            a_fn,
        )

    def forward(
        self,
        tokens,
        return_attn_weights=False,
        return_rep_layers: Optional[Iterable[int]] = None,
    ):
        """
        Run the encoder and return the requested output.

        Args:
            tokens: Batch of token ids; ``AbRep`` requires ``tokens.ndim == 2``.
            return_attn_weights: When true, return the attention weights
                collected by ``AbRep``. Takes precedence over
                ``return_rep_layers``.
            return_rep_layers: Layer indices whose hidden states should be
                returned. ``None`` or any empty iterable means no layers
                are requested.

        Returns:
            The attention weights if ``return_attn_weights`` is true; else the
            mapping of requested layer index to hidden state if any layer was
            requested; else the output of ``AbHead`` for the final hidden
            states.
        """
        # Normalise once: avoids a shared mutable default and makes an empty
        # tuple/set behave exactly like an empty list.
        requested_layers = [] if return_rep_layers is None else list(return_rep_layers)

        representations = self.AbRep(tokens, return_attn_weights, requested_layers)

        if return_attn_weights:
            return representations.attention_weights
        if requested_layers:
            return representations.many_hidden_states
        return self.AbHead(representations.last_hidden_states)

    def get_aa_embeddings(self):
        "Extracts the trained aa_embeddings (the ``nn.Embedding`` layer of AbRep)."
        return self.AbRep.aa_embed_layer
class AbRep(torch.nn.Module):
    """
    AbRep (antibody representations), takes the tokenized sequence and creates hidden_embed (representations).

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_embed_size: Width of the embeddings / hidden states.
        n_attn_heads: Number of attention heads passed to each encoder block.
        n_encoder_blocks: Number of ``TransformerEncoder`` blocks to stack.
        padding_tkn: Token id treated as padding; used as the embedding's
            ``padding_idx`` and to build the padding mask.
        mask_tkn: Token id of the mask token (only stored here).
        layer_norm_eps: Epsilon for the encoder blocks and the final layer norm.
        a_fn: Activation function name passed to each encoder block.
        dropout: Passed to each encoder block as ``attn_dropout``.
    """

    def __init__(
        self,
        vocab_size,
        hidden_embed_size,
        n_attn_heads,
        n_encoder_blocks,
        padding_tkn,
        mask_tkn,
        layer_norm_eps: float = 1e-12,
        a_fn: str = "gelu",
        dropout: float = 0.1,
    ):
        super().__init__()
        self.padding_tkn = padding_tkn
        self.mask_tkn = mask_tkn

        self.aa_embed_layer = nn.Embedding(
            vocab_size,
            hidden_embed_size,
            padding_idx=padding_tkn,
        )
        self.encoder_blocks = nn.ModuleList(
            [
                TransformerEncoder(
                    hidden_embed_size,
                    n_attn_heads,
                    attn_dropout=dropout,
                    layer_norm_eps=layer_norm_eps,
                    a_fn=a_fn,
                )
                for _ in range(n_encoder_blocks)
            ]
        )
        self.layer_norm_after_encoder_blocks = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)

    def forward(
        self,
        tokens,
        return_attn_weights=False,
        return_rep_layers: Optional[Iterable[int]] = None,
    ):
        """
        Embed the tokens and pass them through every encoder block.

        Args:
            tokens: Batch of token ids; must satisfy ``tokens.ndim == 2``.
            return_attn_weights: Passed to each encoder block; when true the
                per-block attention weights are collected.
            return_rep_layers: Layer indices whose hidden states should be
                kept. Index 0 is the embedding output and index ``i`` is the
                output of the ``i``-th encoder block. ``None`` keeps none.

        Returns:
            A ``DataAbRep`` holding the layer-normed final hidden states, a
            dict of the requested intermediate hidden states, and the list of
            collected attention weights (empty unless requested).
        """
        assert tokens.ndim == 2, "tokens must be a 2-dimensional tensor"

        padding_mask = tokens.eq(self.padding_tkn)
        hidden_embed = self.aa_embed_layer(tokens)

        # Copy into a set: constant-time membership tests and no aliasing of
        # the caller's (or a default) container.
        wanted_layers = set() if return_rep_layers is None else set(return_rep_layers)

        rep_layers = {}
        if 0 in wanted_layers:
            rep_layers[0] = hidden_embed

        all_attn_weights = []
        # Layer indices are 1-based for encoder outputs; 0 is the embedding.
        for layer_idx, encoder_block in enumerate(self.encoder_blocks, start=1):
            hidden_embed, attn_weights = encoder_block(hidden_embed, padding_mask, return_attn_weights)
            if layer_idx in wanted_layers:
                rep_layers[layer_idx] = hidden_embed
            if return_attn_weights:
                all_attn_weights.append(attn_weights)

        # Stored intermediate states are taken before this final layer norm.
        hidden_embed = self.layer_norm_after_encoder_blocks(hidden_embed)

        return DataAbRep(
            last_hidden_states=hidden_embed,
            many_hidden_states=rep_layers,
            attention_weights=all_attn_weights,
        )
class AbHead(torch.nn.Module):
    """
    AbHead (antibody head model), turns the hidden_embed (representations)
    into one score per vocabulary entry for each position.

    The output projection uses the ``weights`` matrix handed in by the caller
    plus a bias of length ``vocab_size`` owned by this module.
    """

    def __init__(
        self,
        vocab_size,
        hidden_embed_size,
        weights,
        layer_norm_eps: float = 1e-12,
        a_fn: str = "gelu",
    ):
        super().__init__()

        act_cls, width_factor = get_activation_fn(a_fn)

        # Build the feed-forward stack piece by piece, in execution order.
        dense = nn.Linear(hidden_embed_size, hidden_embed_size * width_factor)
        activation = act_cls()
        norm = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
        self.ff = nn.Sequential(dense, activation, norm)

        self.weights = weights
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_embed):
        """Apply the feed-forward stack, then project with ``weights`` and add the bias."""
        transformed = self.ff(hidden_embed)
        projected = F.linear(transformed, self.weights)
        return projected + self.bias
@dataclass
class DataAbRep:
    """
    Dataclass used to store AbRep output.
    """

    # Hidden states after the final layer norm of AbRep.
    last_hidden_states: torch.FloatTensor
    # Requested intermediate hidden states keyed by layer index
    # (AbRep.forward fills this with a dict: 0 = embedding output,
    # i = output of the i-th encoder block).
    many_hidden_states: Optional[Dict[int, torch.FloatTensor]] = None
    # Attention weights appended once per encoder block by AbRep.forward;
    # an empty list when they were not requested.
    attention_weights: Optional[List[torch.FloatTensor]] = None