"""HELM-BERT model implementation.
This module implements the HELM-BERT model with:
- Disentangled attention (DeBERTa-style)
- Enhanced Mask Decoder (EMD) for MLM
- n-gram Induced Input Encoding (nGiE) layer
"""
import math
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from packaging import version
from torch import _softmax_backward_data
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutputWithPooling,
MaskedLMOutput,
SequenceClassifierOutput,
)
from .configuration_helmbert import HELMBertConfig
# -----------------------------------------------------------------------------
# Utility Functions
# -----------------------------------------------------------------------------
def masked_layer_norm(
layer_norm: nn.LayerNorm, x: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Apply LayerNorm with masking to avoid updates on padding tokens.
Args:
layer_norm: LayerNorm module
x: Input tensor (batch_size, seq_len, hidden_size)
mask: Mask tensor where 0 = padding (ignored), 1 = valid token
Returns:
Normalized tensor with padding positions zeroed out
"""
output = layer_norm(x).to(x.dtype)
if mask is None:
return output
if mask.dim() != x.dim():
if mask.dim() == 4:
mask = mask.squeeze(1).squeeze(1)
mask = mask.unsqueeze(2)
mask = mask.to(output.dtype)
return output * mask
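# Broadcasting note: a [batch, seq] mask (or a [batch, 1, 1, seq] extended
# attention mask, which is first squeezed back to [batch, seq]) is reshaped to
# [batch, seq, 1] and broadcast-multiplied into the normalized activations, so
# padded rows come out as zeros.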
class XSoftmax(torch.autograd.Function):
"""Masked Softmax optimized for memory efficiency."""
@staticmethod
def forward(
ctx, input: torch.Tensor, mask: Optional[torch.Tensor], dim: int
) -> torch.Tensor:
ctx.dim = dim
if mask is not None:
rmask = ~(mask.bool())
if rmask.dim() == 2:
rmask = rmask.unsqueeze(1).unsqueeze(2)
elif rmask.dim() == 3:
rmask = rmask.unsqueeze(2)
output = input.masked_fill(rmask, float("-inf"))
else:
output = input
output = torch.softmax(output, ctx.dim)
if mask is not None:
output.masked_fill_(rmask, 0)
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
(output,) = ctx.saved_tensors
if version.Version(torch.__version__) >= version.Version("1.11.0"):
input_grad = _softmax_backward_data(
grad_output, output, ctx.dim, output.dtype
)
else:
input_grad = _softmax_backward_data(grad_output, output, ctx.dim, output)
return input_grad, None, None
def build_relative_position(
query_size: int,
key_size: int,
bucket_size: int = -1,
max_position: int = 512,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""Build relative position matrix with optional log-bucketing."""
q_ids = torch.arange(query_size, dtype=torch.long, device=device)
k_ids = torch.arange(key_size, dtype=torch.long, device=device)
rel_pos = q_ids.unsqueeze(1) - k_ids.unsqueeze(0)
if bucket_size > 0:
rel_buckets = 0
num_buckets = bucket_size
rel_buckets += (rel_pos > 0).long() * (num_buckets // 2)
rel_pos = torch.abs(rel_pos)
max_exact = num_buckets // 4
is_small = rel_pos < max_exact
rel_pos_if_large = (
max_exact
+ (
torch.log(rel_pos.float() / max_exact)
/ math.log(max_position / max_exact)
* (num_buckets // 4 - 1)
).long()
)
rel_pos_if_large = torch.min(
rel_pos_if_large, torch.full_like(rel_pos_if_large, num_buckets // 2 - 1)
)
rel_buckets += torch.where(is_small, rel_pos, rel_pos_if_large)
return rel_buckets
else:
rel_pos = torch.clamp(rel_pos, -max_position, max_position)
return rel_pos + max_position
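# Worked example (illustrative): without bucketing the matrix is
# clamp(q - k, -max_position, max_position) + max_position, i.e. signed offsets
# shifted into a non-negative range. build_relative_position(3, 3, max_position=512)
# yields
#     [[512, 511, 510],
#      [513, 512, 511],
#      [514, 513, 512]]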
# -----------------------------------------------------------------------------
# Attention Modules
# -----------------------------------------------------------------------------
class DisentangledSelfAttention(nn.Module):
"""Disentangled self-attention with content and position separation.
Implements content-to-content, content-to-position, and position-to-content
attention as described in DeBERTa.
"""
def __init__(self, config: HELMBertConfig):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"hidden_size ({config.hidden_size}) must be divisible by "
f"num_attention_heads ({config.num_attention_heads})"
)
self.num_heads = config.num_attention_heads
self.head_size = config.hidden_size // config.num_attention_heads
self.all_head_size = self.num_heads * self.head_size
# Content projections
self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
# Position attention configuration
self.pos_att_type = [x.strip() for x in config.pos_att_type.lower().split("|")]
self.max_relative_positions = config.max_relative_positions
self.position_buckets = config.position_buckets
self.share_att_key = config.share_att_key
# Position embedding size
self.pos_ebd_size = config.max_relative_positions
if config.position_buckets > 0:
self.pos_ebd_size = config.position_buckets
# Position embeddings
self.rel_embeddings = nn.Embedding(self.pos_ebd_size * 2, config.hidden_size)
# Position projections
if not self.share_att_key:
if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
self.pos_key_proj = nn.Linear(
config.hidden_size, self.all_head_size, bias=True
)
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
self.pos_query_proj = nn.Linear(
config.hidden_size, self.all_head_size, bias=False
)
# Dropout
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.pos_dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
"""Reshape tensor for attention computation."""
new_shape = x.size()[:-1] + (self.num_heads, self.head_size)
x = x.view(*new_shape)
return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
query_states: Optional[torch.Tensor] = None,
relative_pos: Optional[torch.Tensor] = None,
rel_embeddings: Optional[torch.Tensor] = None,
) -> Dict[str, Any]:
"""Forward pass of disentangled attention."""
if query_states is None:
query_states = hidden_states
# Compute Q, K, V
query_layer = self.transpose_for_scores(self.query_proj(query_states)).float()
key_layer = self.transpose_for_scores(self.key_proj(hidden_states)).float()
value_layer = self.transpose_for_scores(self.value_proj(hidden_states))
# Calculate scale factor
scale_factor = 1
if "c2p" in self.pos_att_type:
scale_factor += 1
if "p2c" in self.pos_att_type:
scale_factor += 1
if "p2p" in self.pos_att_type:
scale_factor += 1
scale = 1.0 / math.sqrt(self.head_size * scale_factor)
# Content-to-content attention (c2c)
c2c_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) * scale)
attention_scores = c2c_scores
# Add relative position bias if enabled
if len(self.pos_att_type) > 0 and self.pos_att_type[0]:
rel_att = self._disentangled_attention_bias(
query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
)
if rel_att is not None:
attention_scores = attention_scores + rel_att
        # Subtract the per-row max for numerical stability
attention_scores = (
attention_scores - attention_scores.max(dim=-1, keepdim=True)[0].detach()
)
attention_scores = attention_scores.to(hidden_states.dtype)
# Reshape for XSoftmax
attention_scores = attention_scores.view(
-1, self.num_heads, attention_scores.size(-2), attention_scores.size(-1)
)
# Apply XSoftmax
attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
attention_probs = self.dropout(attention_probs)
# Apply attention to values
attention_probs_flat = attention_probs.view(
-1, attention_probs.size(-2), attention_probs.size(-1)
)
context_layer = torch.bmm(attention_probs_flat, value_layer)
# Reshape output
context_layer = context_layer.view(
-1, self.num_heads, context_layer.size(-2), context_layer.size(-1)
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_shape)
result = {"hidden_states": context_layer, "attention_probs": attention_probs}
return result
def _disentangled_attention_bias(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
relative_pos: Optional[torch.Tensor],
rel_embeddings: Optional[torch.Tensor],
scale_factor: int,
) -> Optional[torch.Tensor]:
"""Compute disentangled attention bias."""
if relative_pos is None:
q_size = query_layer.size(-2)
k_size = key_layer.size(-2)
relative_pos = build_relative_position(
q_size,
k_size,
bucket_size=self.position_buckets,
max_position=self.max_relative_positions,
device=query_layer.device,
)
if relative_pos.dim() == 2:
relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
elif relative_pos.dim() == 3:
relative_pos = relative_pos.unsqueeze(1)
batch_size = query_layer.size(0) // self.num_heads
# Get position embeddings
if rel_embeddings is None:
rel_embeddings = self.rel_embeddings.weight
att_span = self.pos_ebd_size
rel_embeddings = rel_embeddings[
self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :
].unsqueeze(0)
rel_embeddings = self.pos_dropout(rel_embeddings)
score = torch.zeros_like(query_layer[:, :, :1]).expand(
-1, -1, key_layer.size(-2)
)
# Prepare position indices
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
c2p_pos = c2p_pos.squeeze(0).expand(
query_layer.size(0), query_layer.size(1), relative_pos.size(-1)
)
# Content-to-position (c2p)
if "c2p" in self.pos_att_type:
pos_key_layer = (
self.pos_key_proj(rel_embeddings)
if not self.share_att_key
else self.key_proj(rel_embeddings)
)
pos_key_layer = self.transpose_for_scores(pos_key_layer).repeat(
batch_size, 1, 1
)
c2p_scale = 1.0 / math.sqrt(self.head_size * scale_factor)
c2p_att = torch.bmm(
query_layer, pos_key_layer.transpose(-1, -2) * c2p_scale
)
c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos)
score = score + c2p_att
# Position-to-content (p2c)
if "p2c" in self.pos_att_type:
pos_query_layer = (
self.pos_query_proj(rel_embeddings)
if not self.share_att_key
else self.query_proj(rel_embeddings)
)
pos_query_layer = self.transpose_for_scores(pos_query_layer).repeat(
batch_size, 1, 1
)
p2c_scale = 1.0 / math.sqrt(self.head_size * scale_factor)
p2c_att = torch.bmm(
pos_query_layer * p2c_scale, key_layer.transpose(-1, -2)
)
p2c_att = torch.gather(p2c_att, dim=-2, index=c2p_pos)
score = score + p2c_att
return score
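# Scaling note (following the DeBERTa-style decomposition this class implements):
# each enabled term -- content-to-content, content-to-position, position-to-content --
# is scaled by 1 / sqrt(head_size * scale_factor), where scale_factor is 1 plus the
# number of enabled position terms. With pos_att_type = "c2p|p2c" the combined
# score is therefore
#     (Q_c K_c^T + c2p_att + p2c_att) / sqrt(3 * head_size)
# with the position terms gathered at the relative-position index of each
# query/key pair.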
# -----------------------------------------------------------------------------
# Transformer Components
# -----------------------------------------------------------------------------
class HELMBertEmbeddings(nn.Module):
"""Token and position embeddings for HELM-BERT."""
def __init__(self, config: HELMBertConfig):
super().__init__()
self.word_embeddings = nn.Embedding(
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size
)
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass.
Returns:
Tuple of (token_embeddings, position_embeddings)
"""
batch_size, seq_len = input_ids.shape
# Token embeddings
embeddings = self.word_embeddings(input_ids)
# Position embeddings
position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
position_embeds = self.position_embeddings(position_ids)
# Layer norm and dropout
embeddings = masked_layer_norm(self.layer_norm, embeddings, attention_mask)
embeddings = self.dropout(embeddings)
return embeddings, position_embeds
class NgieLayer(nn.Module):
"""n-gram Induced Input Encoding (nGiE) layer.
Captures local n-gram patterns using 1D convolution.
"""
def __init__(self, config: HELMBertConfig):
super().__init__()
self.conv = nn.Conv1d(
in_channels=config.hidden_size,
out_channels=config.hidden_size,
kernel_size=config.ngie_kernel_size,
padding=(config.ngie_kernel_size - 1) // 2,
groups=1,
)
self.activation = nn.Tanh()
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.dropout = nn.Dropout(config.ngie_dropout)
def forward(
self,
hidden_states: torch.Tensor,
residual_states: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""Forward pass.
Args:
hidden_states: Input to convolution (batch, seq, hidden)
residual_states: States for residual connection (batch, seq, hidden)
attention_mask: Mask where 1 = valid, 0 = padding
Returns:
Output with n-gram information incorporated
"""
# Apply 1D convolution
out = (
self.conv(hidden_states.permute(0, 2, 1).contiguous())
.permute(0, 2, 1)
.contiguous()
)
# Create reverse mask for padding
if version.Version(torch.__version__) >= version.Version("1.2.0a"):
rmask = (1 - attention_mask).bool()
else:
rmask = (1 - attention_mask).byte()
# Zero out padding positions
out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
# Apply activation and dropout
out = self.activation(self.dropout(out))
# Residual connection with LayerNorm
output_states = masked_layer_norm(
self.layer_norm, residual_states + out, attention_mask
)
return output_states
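# Shape note: with an odd ngie_kernel_size k, the padding of (k - 1) // 2 keeps
# the sequence length unchanged, so each position mixes a k-token neighborhood
# (e.g. trigrams for k = 3) before the masked residual add and LayerNorm above.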
class TransformerBlock(nn.Module):
"""Transformer block with disentangled attention and GELU FFN."""
def __init__(self, config: HELMBertConfig):
super().__init__()
self.self_attn = DisentangledSelfAttention(config)
self.attn_output_dense = nn.Linear(config.hidden_size, config.hidden_size)
# FFN with GELU
self.linear1 = nn.Sequential(
nn.Linear(config.hidden_size, config.intermediate_size), nn.GELU()
)
self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)
# Normalization and dropout
self.norm1 = nn.LayerNorm(config.hidden_size)
self.norm2 = nn.LayerNorm(config.hidden_size)
self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
def forward(
self,
src: torch.Tensor,
src_key_padding_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
query_states: Optional[torch.Tensor] = None,
relative_pos: Optional[torch.Tensor] = None,
rel_embeddings: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward pass.
Args:
src: Input embeddings [seq_len, batch, hidden]
src_key_padding_mask: Padding mask [batch, seq_len]
output_attentions: Whether to return attention weights
query_states: Optional query for EMD
relative_pos: Relative position indices
rel_embeddings: Relative position embeddings
Returns:
Tuple of (output, optional attention weights)
"""
# Transpose for attention [seq, batch, hidden] -> [batch, seq, hidden]
src_transposed = src.transpose(0, 1)
# Convert padding mask to attention mask (1=valid, 0=padding)
attention_mask = None
if src_key_padding_mask is not None:
attention_mask = (~src_key_padding_mask).float()
query_states_transposed = None
if query_states is not None:
query_states_transposed = query_states.transpose(0, 1)
# Self-attention
attn_result = self.self_attn(
src_transposed,
attention_mask,
output_attentions=output_attentions,
query_states=query_states_transposed,
relative_pos=relative_pos,
rel_embeddings=rel_embeddings,
)
attn_output = attn_result["hidden_states"].transpose(0, 1)
attn_weights = attn_result.get("attention_probs") if output_attentions else None
# Dense projection
attn_output = self.attn_output_dense(attn_output)
# Residual connection
residual_input = query_states if query_states is not None else src
src = residual_input + self.dropout1(attn_output)
# LayerNorm
src = src.transpose(0, 1)
src = masked_layer_norm(self.norm1, src)
src = src.transpose(0, 1)
# FFN
ff_output = self.linear1(src)
ff_output = self.linear2(ff_output)
ff_output = self.dropout2(ff_output)
src = src + ff_output
# LayerNorm
src = src.transpose(0, 1)
src = masked_layer_norm(self.norm2, src)
src = src.transpose(0, 1)
return src, attn_weights
class HELMBertEncoder(nn.Module):
"""Stack of transformer blocks with nGiE layer."""
def __init__(self, config: HELMBertConfig):
super().__init__()
self.config = config
# nGiE layer (applied after first transformer block)
self.ngie_layer = NgieLayer(config)
# Transformer blocks
self.layers = nn.ModuleList(
[TransformerBlock(config) for _ in range(config.num_hidden_layers)]
)
def get_rel_embedding(self) -> Optional[torch.Tensor]:
"""Get relative position embeddings from first layer."""
if len(self.layers) > 0:
first_layer = self.layers[0]
if hasattr(first_layer, "self_attn") and hasattr(
first_layer.self_attn, "rel_embeddings"
):
return first_layer.self_attn.rel_embeddings.weight
return None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
use_emd: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple], Optional[Tuple]]:
"""Forward pass.
Args:
hidden_states: Input embeddings [batch, seq, hidden]
attention_mask: Attention mask [batch, seq]
position_embeddings: Position embeddings for EMD
output_attentions: Whether to return attention weights
output_hidden_states: Whether to return all hidden states
use_emd: Whether to use Enhanced Mask Decoder
Returns:
Tuple of (last_hidden_state, emd_output, all_hidden_states, all_attentions)
"""
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# Store for nGiE
ngie_input_states = hidden_states
# [batch, seq, hidden] -> [seq, batch, hidden]
hidden_states = hidden_states.transpose(0, 1)
# Key padding mask (True = padding)
key_padding_mask = None
if attention_mask is not None:
key_padding_mask = ~attention_mask.bool()
# Store layer[-2] for EMD
layer_minus_2 = None
num_layers = len(self.layers)
for layer_idx, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)
hidden_states, attn_weights = layer(
hidden_states,
src_key_padding_mask=key_padding_mask,
output_attentions=output_attentions,
)
if output_attentions and attn_weights is not None:
all_attentions = all_attentions + (attn_weights,)
# Apply nGiE after first layer
if layer_idx == 0:
hidden_states_batch = hidden_states.transpose(0, 1)
hidden_states_batch = self.ngie_layer(
ngie_input_states, hidden_states_batch, attention_mask
)
hidden_states = hidden_states_batch.transpose(0, 1)
# Store layer[-2] for EMD
if use_emd and layer_idx == num_layers - 2:
layer_minus_2 = hidden_states
# Convert back to [batch, seq, hidden]
hidden_states = hidden_states.transpose(0, 1)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# Enhanced Mask Decoder (EMD) for MLM
emd_output = None
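        # EMD query: the second-to-last layer's output plus the absolute
        # position embeddings, refined by two passes through the shared final
        # layer while attending over the second-to-last layer's states as
        # keys/values.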
if use_emd and layer_minus_2 is not None and position_embeddings is not None:
emd_keys_values = layer_minus_2
emd_query = layer_minus_2.transpose(0, 1)
emd_query = position_embeddings + emd_query
emd_query = emd_query.transpose(0, 1)
rel_embeddings = self.get_rel_embedding()
last_layer = self.layers[-1]
for _ in range(2):
emd_query, _ = last_layer(
emd_keys_values,
src_key_padding_mask=key_padding_mask,
query_states=emd_query,
relative_pos=None,
rel_embeddings=rel_embeddings,
)
emd_output = emd_query.transpose(0, 1)
return hidden_states, emd_output, all_hidden_states, all_attentions
class HELMBertPooler(nn.Module):
"""Mean pooling over sequence."""
def __init__(self, config: HELMBertConfig):
super().__init__()
self.hidden_size = config.hidden_size
def forward(
self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Apply mean pooling.
Args:
hidden_states: [batch, seq, hidden]
attention_mask: [batch, seq]
Returns:
Pooled output [batch, hidden]
"""
if attention_mask is not None:
mask_expanded = (
attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
)
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
eps = torch.finfo(hidden_states.dtype).eps
sum_mask = torch.clamp(mask_expanded.sum(1), min=eps)
return sum_embeddings / sum_mask
else:
return hidden_states.mean(dim=1)
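# Pooling note: with a mask the pooled vector is
#     sum_t(mask_t * h_t) / max(sum_t(mask_t), eps)
# so padding positions contribute neither to the sum nor to the token count;
# without a mask this reduces to a plain mean over the sequence dimension.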
# -----------------------------------------------------------------------------
# Pre-trained Model Base
# -----------------------------------------------------------------------------
class HELMBertPreTrainedModel(PreTrainedModel):
"""Base class for HELM-BERT models."""
config_class = HELMBertConfig
base_model_prefix = "helmbert"
def _init_weights(self, module: nn.Module) -> None:
"""Initialize weights with BERT-style initialization."""
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
# -----------------------------------------------------------------------------
# Model Classes
# -----------------------------------------------------------------------------
class HELMBertModel(HELMBertPreTrainedModel):
"""HELM-BERT base model.
    This model outputs the last hidden states and a mask-aware mean-pooled output.
Example:
>>> from helmbert import HELMBertModel, HELMBertTokenizer
>>> tokenizer = HELMBertTokenizer()
>>> model = HELMBertModel.from_pretrained("./checkpoints/helmbert-base")
>>> inputs = tokenizer("PEPTIDE1{A.C.D.E}$$$$", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooler_output = outputs.pooler_output
"""
def __init__(self, config: HELMBertConfig):
super().__init__(config)
self.config = config
self.embeddings = HELMBertEmbeddings(config)
self.encoder = HELMBertEncoder(config)
self.pooler = HELMBertPooler(config)
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.embeddings.word_embeddings
def set_input_embeddings(self, value: nn.Embedding) -> None:
self.embeddings.word_embeddings = value
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[Tuple, BaseModelOutputWithPooling]:
"""Forward pass.
Args:
input_ids: Token IDs [batch, seq]
attention_mask: Attention mask [batch, seq]
output_attentions: Whether to return attention weights
output_hidden_states: Whether to return all hidden states
return_dict: Whether to return a ModelOutput
Returns:
BaseModelOutputWithPooling or tuple
"""
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
# Embeddings
embeddings, position_embeddings = self.embeddings(input_ids, attention_mask)
# Encoder
encoder_outputs = self.encoder(
embeddings,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_emd=False,
)
last_hidden_state = encoder_outputs[0]
hidden_states = encoder_outputs[2]
attentions = encoder_outputs[3]
# Pooling
pooler_output = self.pooler(last_hidden_state, attention_mask)
if not return_dict:
return (last_hidden_state, pooler_output, hidden_states, attentions)
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=hidden_states,
attentions=attentions,
)
class HELMBertLMHead(nn.Module):
"""MLM head with weight tying (HuggingFace standard)."""
def __init__(self, config: HELMBertConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.activation = nn.GELU()
# Decoder with weight tying (weight tied to embedding, bias is separate)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
hidden_states: [batch, seq, hidden]
Returns:
Logits [batch, seq, vocab]
"""
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.layer_norm(hidden_states)
logits = self.decoder(hidden_states)
return logits
class HELMBertForMaskedLM(HELMBertPreTrainedModel):
"""HELM-BERT for Masked Language Modeling with Enhanced Mask Decoder (EMD).
Example:
>>> from helmbert import HELMBertForMaskedLM, HELMBertTokenizer
>>> tokenizer = HELMBertTokenizer()
>>> model = HELMBertForMaskedLM.from_pretrained("./checkpoints/helmbert-base")
>>> inputs = tokenizer("PEPTIDE1{A.¶.D.E}$$$$", return_tensors="pt") # ¶ is mask
>>> outputs = model(**inputs)
>>> predictions = outputs.logits.argmax(dim=-1)
"""
_tied_weights_keys = ["lm_head.decoder.weight"]
def __init__(self, config: HELMBertConfig):
super().__init__(config)
self.helmbert = HELMBertModel(config)
self.lm_head = HELMBertLMHead(config)
self.post_init()
def get_output_embeddings(self) -> nn.Linear:
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.lm_head.decoder = new_embeddings
def get_input_embeddings(self) -> nn.Embedding:
return self.helmbert.embeddings.word_embeddings
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
use_emd: bool = True,
) -> Union[Tuple, MaskedLMOutput]:
"""Forward pass.
Args:
input_ids: Token IDs [batch, seq]
attention_mask: Attention mask [batch, seq]
labels: Labels for MLM (-100 for non-masked tokens)
output_attentions: Whether to return attention weights
output_hidden_states: Whether to return all hidden states
return_dict: Whether to return a ModelOutput
use_emd: Whether to use Enhanced Mask Decoder
Returns:
MaskedLMOutput or tuple
"""
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
# Embeddings
embeddings, position_embeddings = self.helmbert.embeddings(
input_ids, attention_mask
)
# Encoder with optional EMD
encoder_outputs = self.helmbert.encoder(
embeddings,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_emd=use_emd,
)
# Use EMD output if available, otherwise use last hidden state
if use_emd and encoder_outputs[1] is not None:
sequence_output = encoder_outputs[1]
else:
sequence_output = encoder_outputs[0]
hidden_states = encoder_outputs[2]
attentions = encoder_outputs[3]
# MLM head (weight tying handled by HuggingFace)
prediction_scores = self.lm_head(sequence_output)
# Calculate loss if labels provided
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
)
if not return_dict:
output = (prediction_scores, hidden_states, attentions)
return ((loss,) + output) if loss is not None else output
return MaskedLMOutput(
loss=loss,
logits=prediction_scores,
hidden_states=hidden_states,
attentions=attentions,
)
class MLPHead(nn.Module):
"""MLP head with skip connections for classification/regression.
    Architecture: input -> [Linear (+ skip when dims match) -> GELU -> LayerNorm -> Dropout] x N -> Linear -> output
"""
def __init__(
self,
input_dim: int,
output_dim: int,
hidden_dims: list,
dropout: float = 0.1,
):
super().__init__()
self.layers = nn.ModuleList()
self.norms = nn.ModuleList()
self.dropouts = nn.ModuleList()
prev_dim = input_dim
for hidden_dim in hidden_dims:
self.layers.append(nn.Linear(prev_dim, hidden_dim))
self.norms.append(nn.LayerNorm(hidden_dim))
self.dropouts.append(nn.Dropout(dropout))
prev_dim = hidden_dim
self.output_layer = nn.Linear(prev_dim, output_dim)
self.activation = nn.GELU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer, norm, dropout in zip(self.layers, self.norms, self.dropouts):
identity = x
x = layer(x)
if x.shape == identity.shape:
x = x + identity # Skip connection
x = self.activation(x)
x = norm(x)
x = dropout(x)
return self.output_layer(x)
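# Illustrative usage (hypothetical dimensions, not part of the model): a
# two-layer head mapping 768-d pooled embeddings to a single regression target.
#
#     head = MLPHead(input_dim=768, output_dim=1, hidden_dims=[768, 768], dropout=0.1)
#     preds = head(torch.randn(4, 768))  # -> shape (4, 1)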
class HELMBertForSequenceClassification(HELMBertPreTrainedModel):
"""HELM-BERT for sequence classification/regression.
Example:
>>> from helmbert import HELMBertForSequenceClassification, HELMBertConfig
>>> # Simple linear head (default)
>>> config = HELMBertConfig(num_labels=1)
>>> model = HELMBertForSequenceClassification(config)
>>>
>>> # MLP head with 2 layers (for permeability prediction)
>>> config = HELMBertConfig(num_labels=1, classifier_num_layers=2)
>>> model = HELMBertForSequenceClassification(config)
"""
def __init__(self, config: HELMBertConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.helmbert = HELMBertModel(config)
# Use MLP head if num_layers > 0, otherwise simple linear
if config.classifier_num_layers > 0:
hidden_dims = [config.hidden_size] * config.classifier_num_layers
self.classifier = MLPHead(
input_dim=config.hidden_size,
output_dim=config.num_labels,
hidden_dims=hidden_dims,
dropout=config.classifier_dropout,
)
else:
self.dropout = nn.Dropout(config.classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.post_init()
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[Tuple, SequenceClassifierOutput]:
"""Forward pass.
Args:
input_ids: Token IDs [batch, seq]
attention_mask: Attention mask [batch, seq]
labels: Labels for classification/regression
output_attentions: Whether to return attention weights
output_hidden_states: Whether to return all hidden states
return_dict: Whether to return a ModelOutput
Returns:
SequenceClassifierOutput or tuple
"""
outputs = self.helmbert(
input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
pooled_output = outputs.pooler_output
        # The MLP head applies dropout internally; the plain linear head needs it applied here
if hasattr(self, "dropout"):
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = nn.MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)