"""
BERTose transformer layers.
Transformer blocks adapted for WURCS glycan tokenization.
"""

import math
from typing import Optional

import torch
import torch.nn as nn


class GlycanBERTConfig:
"""Configuration for the BERTose transformer stack."""
def __init__(
self,
vocab_size: int = 102,
hidden_size: int = 384,
num_hidden_layers: int = 6,
num_attention_heads: int = 6,
intermediate_size: int = 1536,
hidden_dropout_prob: float = 0.1,
attention_probs_dropout_prob: float = 0.1,
max_position_embeddings: int = 512,
layer_norm_eps: float = 1e-12,
pad_token_id: int = 0,
mask_token_id: int = 4,
initializer_range: float = 0.02
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.layer_norm_eps = layer_norm_eps
self.pad_token_id = pad_token_id
self.mask_token_id = mask_token_id
self.initializer_range = initializer_range
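
# Illustrative usage (hypothetical override values, not taken from this repo): any of
# the defaults above can be overridden at construction time, subject to hidden_size
# being divisible by num_attention_heads, e.g.
#
#     config = GlycanBERTConfig(hidden_size=512, num_hidden_layers=8, num_attention_heads=8)
#     model = GlycanBERT(config)   # GlycanBERT is defined at the bottom of this module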


class GlycanBERTEmbeddings(nn.Module):
    """
    Token embeddings combined with learned positional embeddings for WURCS glycan token IDs.
    """
def __init__(self, config: GlycanBERTConfig):
super().__init__()
self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids has shape (1, max_position_embeddings); registering it as a buffer
        # keeps it on the module's device and includes it in the state dict.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
"""
Args:
input_ids: Tensor of shape (batch_size, seq_len)
Returns:
Embeddings of shape (batch_size, seq_len, hidden_size)
"""
batch_size, seq_len = input_ids.shape
# Token embeddings
token_embeds = self.token_embeddings(input_ids)
# Position embeddings
position_ids = self.position_ids[:, :seq_len]
position_embeds = self.position_embeddings(position_ids)
# Combine
embeddings = token_embeds + position_embeds
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings


class GlycanBERTAttention(nn.Module):
"""Multi-head self-attention."""
def __init__(self, config: GlycanBERTConfig):
super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0, (
            "hidden_size must be divisible by num_attention_heads"
        )
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = config.hidden_size // config.num_attention_heads
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
"""Reshape for multi-head attention."""
new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_shape)
return x.permute(0, 2, 1, 3) # (batch, heads, seq_len, head_size)
def forward(
self,
hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size, seq_len, hidden_size)
attention_mask: (batch_size, seq_len) - 1 for valid, 0 for padding
Returns:
Attention output: (batch_size, seq_len, hidden_size)
"""
batch_size, seq_len, _ = hidden_states.shape
# Linear projections
query_layer = self.transpose_for_scores(self.query(hidden_states))
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
# Attention scores
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply attention mask
        if attention_mask is not None:
            # Convert the 0/1 padding mask into an additive bias: 0 for valid positions,
            # a large negative value for padded positions so they vanish after softmax.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)
            attention_mask = (1.0 - attention_mask) * -10000.0
attention_scores = attention_scores + attention_mask
# Attention probabilities
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
# Apply attention to values
context_layer = torch.matmul(attention_probs, value_layer)
# Reshape back
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_shape)
return context_layer


class GlycanBERTLayer(nn.Module):
"""Single transformer layer."""
def __init__(self, config: GlycanBERTConfig):
super().__init__()
self.attention = GlycanBERTAttention(config)
self.attention_output = nn.Linear(config.hidden_size, config.hidden_size)
self.attention_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.intermediate = nn.Linear(config.hidden_size, config.intermediate_size)
self.output = nn.Linear(config.intermediate_size, config.hidden_size)
self.output_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self,
hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size, seq_len, hidden_size)
attention_mask: (batch_size, seq_len)
Returns:
Output: (batch_size, seq_len, hidden_size)
"""
# Self-attention
attention_output = self.attention(hidden_states, attention_mask)
attention_output = self.attention_output(attention_output)
attention_output = self.dropout(attention_output)
# Add & Norm
hidden_states = self.attention_layer_norm(hidden_states + attention_output)
# Feed-forward
intermediate_output = self.intermediate(hidden_states)
intermediate_output = nn.functional.gelu(intermediate_output)
layer_output = self.output(intermediate_output)
layer_output = self.dropout(layer_output)
# Add & Norm
layer_output = self.output_layer_norm(hidden_states + layer_output)
return layer_output


class GlycanBERT(nn.Module):
"""
BERTose transformer stack for masked language modeling.
"""
def __init__(self, config: GlycanBERTConfig):
super().__init__()
self.config = config
# Embeddings
self.embeddings = GlycanBERTEmbeddings(config)
# Transformer layers
self.layers = nn.ModuleList([GlycanBERTLayer(config) for _ in range(config.num_hidden_layers)])
# MLM head
self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
"""Initialize weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(
self,
input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
):
"""
Args:
input_ids: (batch_size, seq_len)
attention_mask: (batch_size, seq_len) - 1 for valid, 0 for padding
labels: (batch_size, seq_len) - token IDs to predict, -100 for positions to ignore
Returns:
If labels provided: (loss, logits)
Else: logits
"""
# Create attention mask if not provided
if attention_mask is None:
attention_mask = (input_ids != self.config.pad_token_id).float()
# Embeddings
hidden_states = self.embeddings(input_ids)
# Transformer layers
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask)
# MLM prediction
logits = self.mlm_head(hidden_states)
# Calculate loss if labels provided
loss = None
if labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100, matching the labels convention
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if loss is not None:
return loss, logits
return logits
def get_embeddings(
self,
input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Get contextualized embeddings (for downstream tasks).
Args:
input_ids: (batch_size, seq_len)
attention_mask: (batch_size, seq_len)
Returns:
Embeddings: (batch_size, seq_len, hidden_size)
"""
if attention_mask is None:
attention_mask = (input_ids != self.config.pad_token_id).float()
hidden_states = self.embeddings(input_ids)
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask)
return hidden_states
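

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch, not part of the original module.
    # Token IDs are random placeholders (not real WURCS vocabulary indices) and the
    # masking pattern below is made up purely to exercise the MLM loss path.
    config = GlycanBERTConfig()
    model = GlycanBERT(config)
    input_ids = torch.randint(5, config.vocab_size, (2, 16))
    input_ids[:, -4:] = config.pad_token_id                      # simulate right padding
    attention_mask = (input_ids != config.pad_token_id).float()
    labels = torch.full_like(input_ids, -100)                    # -100 = ignored by the loss
    labels[:, 2] = input_ids[:, 2]                               # remember the original token...
    input_ids[:, 2] = config.mask_token_id                       # ...then replace it with the mask token
    loss, logits = model(input_ids, attention_mask, labels)
    print("loss:", loss.item(), "logits shape:", tuple(logits.shape))
    print("embeddings shape:", tuple(model.get_embeddings(input_ids, attention_mask).shape))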