# PeptideMLM_base / ChemPepMTR.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtune.modules import RotaryPositionalEmbeddings
from transformers import PreTrainedModel
from .config import model_config
from typing import Mapping
from transformers.tokenization_utils_base import BatchEncoding
class SwiGLU(nn.Module):
def __init__(self, input_dim, hidden_dim):
super().__init__()
self.linear1 = nn.Linear(input_dim, hidden_dim * 2, bias=True)
self.linear2 = nn.Linear(hidden_dim, input_dim, bias=True)
self.dropout = nn.Dropout(0.1) # Add dropout for regularization
def forward(self, x):
        # x: (..., input_dim); linear1 doubles the width so the result can be
        # split into a SiLU-gated half (x1) and a plain linear half (x2)
        x1, x2 = self.linear1(x).chunk(2, dim=-1)
        output = self.linear2(F.silu(x1) * x2)
return self.dropout(output)
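
# A minimal shape sketch for SwiGLU (illustrative only; the sizes below are made
# up and are not tied to this model's config):
#
#   ffn = SwiGLU(input_dim=64, hidden_dim=128)
#   y = ffn(torch.randn(2, 10, 64))   # -> torch.Size([2, 10, 64])
#
# linear1 projects to 2 * hidden_dim so its output can be split into a SiLU-gated
# half and a plain half before linear2 maps the product back to input_dim.
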
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, max_seq_len):
super().__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
self.rotary = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=max_seq_len)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.dropout = nn.Dropout(0.1) # Add dropout for regularization
def forward(self, x, input_pos=None, mask=None):
B, T, C = x.shape # Batch, sequence, embedding dim
# project into queries, keys, and values
q, k, v = self.qkv_proj(x).view(B, T, 3, self.num_heads, self.head_dim).unbind(2) # (B, T, num_heads, head_dim)
# Apply rotary positional embeddings to queries and keys
q, k = self.rotary(q, input_pos=input_pos), self.rotary(k, input_pos=input_pos)
# Reshape to (B, num_heads, T, head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
        if mask is not None:
            # build an additive attention bias: 0 at valid positions and a large
            # negative value (effectively -inf after softmax) at padding positions
            mask = mask.to(dtype=torch.float32)
            mask = (1.0 - mask) * -1e9
            # mask: (B, T) -> (B, 1, 1, T) -> (B, 1, T, T), broadcast across heads
            mask = mask.unsqueeze(1).unsqueeze(2)
            mask = mask.expand(B, 1, T, T)
# Scaled dot-product attention
attn_output = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=mask)
attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
attn_output = self.out_proj(attn_output)
return self.dropout(attn_output)
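
# Hedged usage sketch for MultiHeadAttention (toy sizes, not the released
# configuration). The mask follows the Hugging Face convention: 1 = real token,
# 0 = padding, and padded positions receive a large negative attention bias.
#
#   attn = MultiHeadAttention(embed_dim=64, num_heads=4, max_seq_len=32)
#   x = torch.randn(2, 10, 64)
#   mask = torch.ones(2, 10)
#   mask[:, 8:] = 0                    # last two positions are padding
#   out = attn(x, mask=mask)           # -> torch.Size([2, 10, 64])
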
class UnifiedTransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, ffn_hidden_dim, max_seq_len):
super().__init__()
self.attn_norm = nn.LayerNorm(embed_dim)
self.attn = MultiHeadAttention(embed_dim, num_heads, max_seq_len)
self.ffn_norm = nn.LayerNorm(embed_dim)
self.ffn = SwiGLU(embed_dim, ffn_hidden_dim)
def forward(self, x, input_pos=None, mask=None):
x = x + self.attn(self.attn_norm(x), input_pos=input_pos, mask=mask)
x = x + self.ffn(self.ffn_norm(x))
return x
class TransformerStack(nn.Module):
def __init__(self, num_blocks, embed_dim, num_heads, ffn_hidden_dim, max_seq_len):
super().__init__()
self.blocks = nn.ModuleList([
UnifiedTransformerBlock(embed_dim, num_heads, ffn_hidden_dim, max_seq_len)
for _ in range(num_blocks)
])
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x, input_pos=None, mask=None):
for block in self.blocks:
x = block(x, input_pos=input_pos, mask=mask)
return self.norm(x)
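
# Sketch of how the stack composes: num_blocks pre-norm transformer blocks
# followed by a final LayerNorm (toy sizes, assumed purely for illustration):
#
#   stack = TransformerStack(num_blocks=2, embed_dim=64, num_heads=4,
#                            ffn_hidden_dim=128, max_seq_len=32)
#   h = stack(torch.randn(2, 10, 64))  # -> torch.Size([2, 10, 64])
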
class MLM_core(nn.Module):
def __init__(
self,
vocab_size: int,
embed_dim: int,
num_blocks: int,
num_heads: int,
ffn_hidden_dim: int,
output_dim: int,
max_seq_len: int,
):
super().__init__()
self.embed = nn.Embedding(vocab_size, embed_dim)
self.transformer = TransformerStack(
num_blocks, embed_dim, num_heads, ffn_hidden_dim, max_seq_len
)
self.sequence_head = nn.Linear(embed_dim, output_dim, bias=True)
def forward(self, ids, mask=None, pad_token_id=0, input_pos=None):
x = self.embed(ids)
x = self.transformer(x, mask=mask, input_pos=input_pos)
        # per-position vocabulary logits for the MLM objective
        logits = self.sequence_head(x)
        # mean pool over non-pad positions: zero out pad embeddings and divide by
        # the number of real tokens rather than the full sequence length
        pad_mask = (ids != pad_token_id).unsqueeze(-1).to(x.dtype)  # (B, T, 1)
        mean_pool = (x * pad_mask).sum(dim=1) / pad_mask.sum(dim=1).clamp(min=1)
outputs = {
'logits': logits,
'last_layer': x,
'mean_pool': mean_pool
}
return outputs
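
# Hedged sketch of the MLM_core output dictionary (toy sizes; pad_token_id=0 as
# in the default argument above):
#
#   core = MLM_core(vocab_size=32, embed_dim=64, num_blocks=2, num_heads=4,
#                   ffn_hidden_dim=128, output_dim=32, max_seq_len=32)
#   out = core(torch.randint(1, 32, (2, 10)))
#   out["logits"]      # (2, 10, 32) per-position vocabulary logits
#   out["last_layer"]  # (2, 10, 64) final hidden states
#   out["mean_pool"]   # (2, 64)     pad-aware sequence embedding
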
class MLM_model(PreTrainedModel): # HF-facing class name
config_class = model_config
def __init__(self, config):
super().__init__(config)
self.model = MLM_core(
vocab_size=config.vocab_size,
embed_dim=config.embed_dim,
num_blocks=config.num_blocks,
num_heads=config.num_heads,
ffn_hidden_dim=config.ffn_hidden_dim,
output_dim=config.output_dim,
max_seq_len=config.max_seq_len,
)
self.post_init() # Initialize weights and apply final processing
    def forward(self, x=None, **kwargs):
        # accept a tokenizer output (BatchEncoding / dict), explicit keyword
        # arguments (input_ids=..., attention_mask=...), or a raw tensor of ids
if isinstance(x, (BatchEncoding, Mapping)):
return self.model(x.get("input_ids"), mask=x.get("attention_mask"))
if "input_ids" in kwargs or "attention_mask" in kwargs:
return self.model(kwargs.get("input_ids"), mask=kwargs.get("attention_mask"))
return self.model(x)
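

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the hyperparameters below are tiny
    # placeholders and do not correspond to the released checkpoint. Because of
    # the relative `.config` import above, run this as a module from the package
    # root (e.g. `python -m <package_name>.ChemPepMTR`) rather than as a script.
    torch.manual_seed(0)
    core = MLM_core(
        vocab_size=32,
        embed_dim=64,
        num_blocks=2,
        num_heads=4,
        ffn_hidden_dim=128,
        output_dim=32,
        max_seq_len=64,
    )
    ids = torch.randint(1, 32, (2, 12))   # (batch, seq_len) of token ids
    ids[:, 10:] = 0                       # pad token id is 0 by default
    attention_mask = (ids != 0).float()   # 1 = real token, 0 = padding
    outputs = core(ids, mask=attention_mask)
    print(outputs["logits"].shape)        # torch.Size([2, 12, 32])
    print(outputs["mean_pool"].shape)     # torch.Size([2, 64])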