# Source: aaronkollasch / ablang2 / ablang.py
# Duplicate from hemantn/ablang2 (revision e766090)
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from .encoderblock import TransformerEncoder, get_activation_fn
class AbLang(torch.nn.Module):
    """
    AbLang inspired by ESM-2's architecture.

    Wraps two sub-modules:

    * ``self.AbRep``  - turns tokens into hidden representations.
    * ``self.AbHead`` - turns the final hidden representations into
      per-position scores over the vocabulary. It is given the encoder's
      embedding matrix (``self.AbRep.aa_embed_layer.weight``) so the same
      tensor is used for the input embedding and the output projection.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        hidden_embed_size: Width of the hidden representations.
        n_attn_heads: Number of attention heads per encoder block.
        n_encoder_blocks: Number of stacked encoder blocks.
        padding_tkn: Token id used for padding.
        mask_tkn: Token id used for masking (forwarded to ``AbRep``).
        layer_norm_eps: Epsilon used by the layer norms.
        a_fn: Name of the activation function, resolved by
            ``get_activation_fn``.
        dropout: Dropout rate forwarded to ``AbRep``. Note the default here
            is ``0.0`` and is always passed explicitly, so ``AbRep``'s own
            default of ``0.1`` is never used through this class.
    """

    def __init__(
        self,
        vocab_size,
        hidden_embed_size,
        n_attn_heads,
        n_encoder_blocks,
        padding_tkn,
        mask_tkn,
        layer_norm_eps: float = 1e-12,
        a_fn: str = "gelu",
        dropout: float = 0.0,
    ):
        super().__init__()

        self.AbRep = AbRep(
            vocab_size,
            hidden_embed_size,
            n_attn_heads,
            n_encoder_blocks,
            padding_tkn,
            mask_tkn,
            layer_norm_eps,
            a_fn,
            dropout,
        )
        # Weight tying: the head projects back onto the vocabulary with the
        # very same tensor that embeds the input tokens.
        self.AbHead = AbHead(
            vocab_size,
            hidden_embed_size,
            self.AbRep.aa_embed_layer.weight,
            layer_norm_eps,
            a_fn,
        )

    def forward(self, tokens, return_attn_weights=False, return_rep_layers=None):
        """
        Run the encoder and return one of three possible outputs.

        Args:
            tokens: Token ids, passed unchanged to ``AbRep``.
            return_attn_weights: If true, return the attention weights
                collected by ``AbRep`` (takes precedence over
                ``return_rep_layers``).
            return_rep_layers: Iterable of layer indices whose hidden states
                should be returned. ``None`` or any empty iterable
                (list, tuple, set, ...) means "no layers requested".

        Returns:
            * ``representations.attention_weights`` if ``return_attn_weights``;
            * ``representations.many_hidden_states`` if any layer was requested;
            * otherwise the output of ``AbHead`` on the last hidden states.
        """
        # Normalise once to a list. The previous `return_rep_layers != []`
        # comparison treated an empty tuple/set as "layers requested" and
        # returned an empty dict instead of the head output.
        rep_layers = [] if return_rep_layers is None else list(return_rep_layers)

        representations = self.AbRep(tokens, return_attn_weights, rep_layers)

        if return_attn_weights:
            return representations.attention_weights
        if rep_layers:
            return representations.many_hidden_states
        return self.AbHead(representations.last_hidden_states)

    def get_aa_embeddings(self):
        "Extracts the trained aa_embeddings (returns the embedding layer itself)."
        return self.AbRep.aa_embed_layer
class AbRep(torch.nn.Module):
    """
    AbRep (antibody representations), takes the tokenized sequence and creates
    hidden_embed (representations).

    The module embeds the tokens, runs them through ``n_encoder_blocks``
    ``TransformerEncoder`` blocks and applies a final layer norm.

    Args:
        vocab_size: Number of rows in the token embedding table.
        hidden_embed_size: Width of the embeddings / hidden representations.
        n_attn_heads: Number of attention heads given to each encoder block.
        n_encoder_blocks: Number of encoder blocks to stack.
        padding_tkn: Token id used for padding; used as the embedding's
            ``padding_idx`` and to build the padding mask in ``forward``.
        mask_tkn: Token id used for masking. Stored on the instance but not
            read anywhere inside this class.
        layer_norm_eps: Epsilon for the encoder blocks and the final layer norm.
        a_fn: Activation function name forwarded to each encoder block.
        dropout: Forwarded to each encoder block as ``attn_dropout``.
    """

    def __init__(
        self,
        vocab_size,
        hidden_embed_size,
        n_attn_heads,
        n_encoder_blocks,
        padding_tkn,
        mask_tkn,
        layer_norm_eps: float = 1e-12,
        a_fn: str = "gelu",
        dropout: float = 0.1,
    ):
        super().__init__()
        self.padding_tkn = padding_tkn
        self.mask_tkn = mask_tkn

        self.aa_embed_layer = nn.Embedding(
            vocab_size,
            hidden_embed_size,
            padding_idx=padding_tkn,
        )
        self.encoder_blocks = nn.ModuleList(
            [
                TransformerEncoder(
                    hidden_embed_size,
                    n_attn_heads,
                    attn_dropout=dropout,
                    layer_norm_eps=layer_norm_eps,
                    a_fn=a_fn,
                )
                for _ in range(n_encoder_blocks)
            ]
        )
        self.layer_norm_after_encoder_blocks = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)

    def forward(
        self,
        tokens,
        return_attn_weights=False,
        return_rep_layers=None,
    ):
        """
        Encode ``tokens`` and collect the requested intermediate outputs.

        Args:
            tokens: 2-D tensor of token ids
                (presumably ``(batch, sequence)`` - confirm with callers).
            return_attn_weights: Forwarded to every encoder block; when true
                the attention weights each block returns are collected.
            return_rep_layers: Iterable of layer indices to keep. ``0`` is the
                embedding output, ``k`` is the output of the k-th encoder
                block (1-based, taken before the final layer norm). ``None``
                or an empty iterable keeps nothing. Indices that match no
                layer are silently ignored.

        Returns:
            DataAbRep with ``last_hidden_states`` (after the final layer
            norm), ``many_hidden_states`` (dict: layer index -> hidden state)
            and ``attention_weights`` (list, one entry per encoder block when
            ``return_attn_weights`` is true, otherwise empty).
        """
        # Kept as an assert so callers still see AssertionError; the message
        # only makes the failure easier to diagnose.
        assert tokens.ndim == 2, f"tokens must be a 2-D tensor, got ndim={tokens.ndim}"

        padding_mask = tokens.eq(self.padding_tkn)
        hidden_embed = self.aa_embed_layer(tokens)

        # A set gives O(1) membership tests in the loop below; None is
        # accepted so the signature needs no mutable default.
        wanted_layers = set() if return_rep_layers is None else set(return_rep_layers)

        rep_layers = {}
        if 0 in wanted_layers:
            rep_layers[0] = hidden_embed

        all_attn_weights = []
        for layer_idx, encoder_block in enumerate(self.encoder_blocks, start=1):
            hidden_embed, attn_weights = encoder_block(hidden_embed, padding_mask, return_attn_weights)

            if layer_idx in wanted_layers:
                rep_layers[layer_idx] = hidden_embed

            if return_attn_weights:
                all_attn_weights.append(attn_weights)

        hidden_embed = self.layer_norm_after_encoder_blocks(hidden_embed)

        return DataAbRep(
            last_hidden_states=hidden_embed,
            many_hidden_states=rep_layers,
            attention_weights=all_attn_weights,
        )
class AbHead(torch.nn.Module):
    """
    AbHead (antibody head model): maps hidden_embed (representations) to one
    score per vocabulary entry at every position.

    The hidden states first pass through ``self.ff`` (linear layer,
    activation, layer norm) and are then projected with the externally
    supplied ``weights`` matrix plus a learned bias. ``forward`` returns these
    raw scores (logits); no softmax is applied here.

    Args:
        vocab_size: Length of the learned output bias.
        hidden_embed_size: Width of the incoming hidden representations.
        weights: Projection matrix used by ``F.linear`` in ``forward``.
        layer_norm_eps: Epsilon of the layer norm inside ``self.ff``.
        a_fn: Activation name resolved through ``get_activation_fn``, which
            yields the activation class and a width multiplier for the
            linear layer.
    """

    def __init__(
        self,
        vocab_size,
        hidden_embed_size,
        weights,
        layer_norm_eps: float = 1e-12,
        a_fn: str = "gelu",
    ):
        super().__init__()

        activation_cls, width_scale = get_activation_fn(a_fn)

        # Built in the same order as they appear in the Sequential so that
        # parameter initialisation happens in an unchanged sequence.
        expand = nn.Linear(hidden_embed_size, hidden_embed_size * width_scale)
        nonlinearity = activation_cls()
        normalise = nn.LayerNorm(hidden_embed_size, eps=layer_norm_eps)
        self.ff = torch.nn.Sequential(expand, nonlinearity, normalise)

        self.weights = weights
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_embed):
        """Return the logits for ``hidden_embed``: ``ff`` -> linear projection -> bias."""
        transformed = self.ff(hidden_embed)
        projected = F.linear(transformed, self.weights)
        return projected + self.bias
@dataclass
class DataAbRep:
    """
    Dataclass used to store AbRep output.

    The annotations describe what ``AbRep.forward`` actually stores: a dict
    keyed by layer index and a list of attention weights (the previous
    ``Tuple`` annotations did not match either value).
    """

    # Hidden states after the final layer norm.
    last_hidden_states: torch.FloatTensor
    # Requested intermediate hidden states, keyed by layer index
    # (0 = embedding output, k = output of the k-th encoder block).
    many_hidden_states: Optional[Dict[int, torch.FloatTensor]] = None
    # Attention weights collected from the encoder blocks, in block order;
    # empty when attention weights were not requested.
    attention_weights: Optional[List[torch.FloatTensor]] = None