|
|
"""HELM-BERT model implementation. |
|
|
|
|
|
This module implements the HELM-BERT model with: |
|
|
- Disentangled attention (DeBERTa-style) |
|
|
- Enhanced Mask Decoder (EMD) for MLM |
|
|
- n-gram Induced Encoding (nGiE) layer |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Any, Dict, Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from packaging import version |
|
|
from torch import _softmax_backward_data |
|
|
from transformers import PreTrainedModel |
|
|
from transformers.modeling_outputs import ( |
|
|
BaseModelOutputWithPooling, |
|
|
MaskedLMOutput, |
|
|
SequenceClassifierOutput, |
|
|
) |
|
|
|
|
|
from .configuration_helmbert import HELMBertConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def masked_layer_norm(
    layer_norm: nn.LayerNorm, x: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Apply LayerNorm and zero out padding positions afterwards.

    Args:
        layer_norm: LayerNorm module to apply.
        x: Input tensor of shape (batch_size, seq_len, hidden_size).
        mask: Optional mask where 1 marks a valid token and 0 marks padding.

    Returns:
        Normalized tensor with padded positions set to zero.
    """
    normalized = layer_norm(x).to(x.dtype)
    if mask is None:
        return normalized
    # Bring the mask to the same rank as x: collapse the broadcast head/query
    # dims of a 4-D attention mask, then add a trailing hidden dimension.
    if mask.dim() != x.dim():
        if mask.dim() == 4:
            mask = mask.squeeze(1).squeeze(1)
        mask = mask.unsqueeze(2)
    return normalized * mask.to(normalized.dtype)
|
|
|
|
|
|
|
|
class XSoftmax(torch.autograd.Function):
    """Masked Softmax optimized for memory efficiency.

    Instead of saving both the raw scores and the mask for backward, only the
    already-masked softmax output is stashed; the gradient is reconstructed
    with torch's private ``_softmax_backward_data`` kernel.
    """

    @staticmethod
    def forward(
        ctx, input: torch.Tensor, mask: Optional[torch.Tensor], dim: int
    ) -> torch.Tensor:
        """Softmax over ``dim`` with masked positions forced to exactly 0.

        Args:
            ctx: Autograd context; stores ``dim`` and the softmax output.
            input: Attention scores (4-D: batch, heads, query, key).
            mask: 1 = keep, 0 = drop; a 2-D or 3-D mask is expanded so it
                broadcasts against the 4-D scores.
            dim: Dimension the softmax is taken over.
        """
        ctx.dim = dim
        if mask is not None:
            # rmask marks the positions to drop (inverse of the keep-mask).
            rmask = ~(mask.bool())
            if rmask.dim() == 2:
                rmask = rmask.unsqueeze(1).unsqueeze(2)
            elif rmask.dim() == 3:
                rmask = rmask.unsqueeze(2)
            output = input.masked_fill(rmask, float("-inf"))
        else:
            output = input
        output = torch.softmax(output, ctx.dim)
        if mask is not None:
            # Zero dropped positions in-place so they contribute nothing
            # downstream (and any NaNs from fully-masked rows at those
            # positions are wiped).
            output.masked_fill_(rmask, 0)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
        """Backward pass; only ``input`` receives a gradient (mask/dim get None)."""
        (output,) = ctx.saved_tensors
        # torch >= 1.11 changed the last argument of _softmax_backward_data
        # from the input tensor to its dtype.
        if version.Version(torch.__version__) >= version.Version("1.11.0"):
            input_grad = _softmax_backward_data(
                grad_output, output, ctx.dim, output.dtype
            )
        else:
            input_grad = _softmax_backward_data(grad_output, output, ctx.dim, output)
        return input_grad, None, None
|
|
|
|
|
|
|
|
def build_relative_position(
    query_size: int,
    key_size: int,
    bucket_size: int = -1,
    max_position: int = 512,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build a relative position matrix with optional log-bucketing.

    Args:
        query_size: Number of query positions.
        key_size: Number of key positions.
        bucket_size: If > 0, distances are mapped into this many buckets
            (sign encoded in the upper/lower half of the bucket range,
            larger magnitudes sharing log-spaced buckets).
        max_position: Clamp limit for raw distances / log-bucket ceiling.
        device: Device on which to build the index tensors.

    Returns:
        A (query_size, key_size) long tensor. With bucketing, entries are
        bucket indices in [0, bucket_size). Without bucketing, entries are
        *signed* relative distances clamped to [-max_position,
        max_position - 1]; the consumer adds ``att_span`` exactly once
        before indexing the embedding table.
    """
    q_ids = torch.arange(query_size, dtype=torch.long, device=device)
    k_ids = torch.arange(key_size, dtype=torch.long, device=device)
    rel_pos = q_ids.unsqueeze(1) - k_ids.unsqueeze(0)

    if bucket_size > 0:
        rel_buckets = 0
        num_buckets = bucket_size
        # Upper half of the bucket range encodes positive (query-after-key)
        # distances, lower half non-positive ones.
        rel_buckets += (rel_pos > 0).long() * (num_buckets // 2)
        rel_pos = torch.abs(rel_pos)

        # Distances below max_exact get their own bucket; larger distances
        # share log-spaced buckets. Guard against num_buckets < 4 which would
        # make max_exact == 0 and divide by zero below.
        max_exact = max(num_buckets // 4, 1)
        is_small = rel_pos < max_exact

        # Clamp the log argument to >= 1 so distance 0 cannot produce
        # NaN/-inf; those entries are replaced by the is_small branch anyway.
        rel_pos_if_large = (
            max_exact
            + (
                torch.log(torch.clamp(rel_pos, min=1).float() / max_exact)
                / math.log(max_position / max_exact)
                * (num_buckets // 4 - 1)
            ).long()
        )
        rel_pos_if_large = torch.min(
            rel_pos_if_large, torch.full_like(rel_pos_if_large, num_buckets // 2 - 1)
        )

        rel_buckets += torch.where(is_small, rel_pos, rel_pos_if_large)
        return rel_buckets
    else:
        # Bug fix: this branch previously returned ``rel_pos + max_position``,
        # but _disentangled_attention_bias adds att_span (== max_position)
        # again before the clamped gather, so almost all positive distances
        # collapsed onto the last embedding row. Return signed distances and
        # let the caller apply the single offset.
        return torch.clamp(rel_pos, -max_position, max_position - 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DisentangledSelfAttention(nn.Module):
    """Disentangled self-attention with content and position separation.

    Implements content-to-content, content-to-position, and position-to-content
    attention as described in DeBERTa.
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__()

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )

        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_heads * self.head_size

        # Content projections used by every attention term.
        self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

        # e.g. "c2p|p2c" -> ["c2p", "p2c"]. An empty string yields [""], which
        # disables all relative terms via the truthiness guard in forward().
        self.pos_att_type = [x.strip() for x in config.pos_att_type.lower().split("|")]
        self.max_relative_positions = config.max_relative_positions
        self.position_buckets = config.position_buckets
        self.share_att_key = config.share_att_key

        # Half-width of the relative-position embedding table; bucketing (when
        # enabled) replaces the raw distance range.
        self.pos_ebd_size = config.max_relative_positions
        if config.position_buckets > 0:
            self.pos_ebd_size = config.position_buckets

        self.rel_embeddings = nn.Embedding(self.pos_ebd_size * 2, config.hidden_size)

        # Without key sharing, the relative embeddings get dedicated
        # projections instead of reusing query_proj/key_proj.
        if not self.share_att_key:
            if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
                self.pos_key_proj = nn.Linear(
                    config.hidden_size, self.all_head_size, bias=True
                )
            if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
                self.pos_query_proj = nn.Linear(
                    config.hidden_size, self.all_head_size, bias=False
                )

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.pos_dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, S, H*D) -> (B*H, S, D) for batched attention matmuls.

        Note: after the first view, ``x`` is (B, S, H, D), so x.size(1) is the
        sequence length and x.size(-1) is the head size used in the final view.
        """
        new_shape = x.size()[:-1] + (self.num_heads, self.head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        query_states: Optional[torch.Tensor] = None,
        relative_pos: Optional[torch.Tensor] = None,
        rel_embeddings: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Forward pass of disentangled attention.

        Args:
            hidden_states: Key/value source, (batch, seq_len, hidden).
            attention_mask: Keep-mask (1 = valid) forwarded to XSoftmax.
            output_attentions: Accepted for interface compatibility; the
                attention probabilities are always included in the result.
            query_states: Optional separate query source (used by the EMD);
                defaults to ``hidden_states`` for ordinary self-attention.
            relative_pos: Precomputed relative-position indices, or None to
                build them from the query/key lengths.
            rel_embeddings: External relative-position embedding table; the
                module's own ``rel_embeddings`` is used when None.

        Returns:
            Dict with "hidden_states" (batch, q_len, hidden) and
            "attention_probs" (batch, heads, q_len, k_len).
        """
        if query_states is None:
            query_states = hidden_states

        # Queries/keys are upcast to float32 so score accumulation is stable
        # even under mixed precision; values keep the working dtype.
        query_layer = self.transpose_for_scores(self.query_proj(query_states)).float()
        key_layer = self.transpose_for_scores(self.key_proj(hidden_states)).float()
        value_layer = self.transpose_for_scores(self.value_proj(hidden_states))

        # DeBERTa scaling: divide by sqrt(d * k) where k counts the enabled
        # attention terms (content + each relative term).
        scale_factor = 1
        if "c2p" in self.pos_att_type:
            scale_factor += 1
        if "p2c" in self.pos_att_type:
            scale_factor += 1
        if "p2p" in self.pos_att_type:
            scale_factor += 1

        scale = 1.0 / math.sqrt(self.head_size * scale_factor)

        # Content-to-content term; shapes are (B*H, q_len, k_len).
        c2c_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) * scale)
        attention_scores = c2c_scores

        # Guard handles pos_att_type == "" which splits to [""] (falsy entry).
        if len(self.pos_att_type) > 0 and self.pos_att_type[0]:
            rel_att = self._disentangled_attention_bias(
                query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
            )
            if rel_att is not None:
                attention_scores = attention_scores + rel_att

        # Max-subtraction for numerical stability before softmax; detached so
        # it does not affect gradients.
        attention_scores = (
            attention_scores - attention_scores.max(dim=-1, keepdim=True)[0].detach()
        )
        attention_scores = attention_scores.to(hidden_states.dtype)

        # Back to 4-D (B, H, q_len, k_len) so XSoftmax can broadcast the mask.
        attention_scores = attention_scores.view(
            -1, self.num_heads, attention_scores.size(-2), attention_scores.size(-1)
        )

        attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
        attention_probs = self.dropout(attention_probs)

        # Flatten heads again for the bmm with the (B*H, k_len, D) values.
        attention_probs_flat = attention_probs.view(
            -1, attention_probs.size(-2), attention_probs.size(-1)
        )
        context_layer = torch.bmm(attention_probs_flat, value_layer)

        # (B*H, q_len, D) -> (B, q_len, H*D).
        context_layer = context_layer.view(
            -1, self.num_heads, context_layer.size(-2), context_layer.size(-1)
        )
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        result = {"hidden_states": context_layer, "attention_probs": attention_probs}
        return result

    def _disentangled_attention_bias(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        relative_pos: Optional[torch.Tensor],
        rel_embeddings: Optional[torch.Tensor],
        scale_factor: int,
    ) -> Optional[torch.Tensor]:
        """Compute the relative-position bias (c2p and p2c terms).

        Args:
            query_layer: (batch*heads, q_len, head_size), float32.
            key_layer: (batch*heads, k_len, head_size), float32.
            relative_pos: Optional precomputed relative-position indices.
            rel_embeddings: Optional external relative embedding table.
            scale_factor: Number of enabled attention terms, used in scaling.

        Returns:
            Bias of shape (batch*heads, q_len, k_len); a zero tensor when no
            enabled term contributes (annotated Optional, but this body always
            returns a tensor).
        """
        if relative_pos is None:
            q_size = query_layer.size(-2)
            k_size = key_layer.size(-2)
            relative_pos = build_relative_position(
                q_size,
                k_size,
                bucket_size=self.position_buckets,
                max_position=self.max_relative_positions,
                device=query_layer.device,
            )

        # Normalize to 4-D (1, 1, q_len, k_len) regardless of input rank.
        if relative_pos.dim() == 2:
            relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
        elif relative_pos.dim() == 3:
            relative_pos = relative_pos.unsqueeze(1)

        # query_layer is flattened (B*H, ...), so recover the true batch size.
        batch_size = query_layer.size(0) // self.num_heads

        if rel_embeddings is None:
            rel_embeddings = self.rel_embeddings.weight

        # att_span == pos_ebd_size, so this slice spans the whole table; it is
        # kept for parity with the DeBERTa reference implementation.
        att_span = self.pos_ebd_size
        rel_embeddings = rel_embeddings[
            self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :
        ].unsqueeze(0)
        rel_embeddings = self.pos_dropout(rel_embeddings)

        # Zero bias accumulator broadcast to (B*H, q_len, k_len).
        score = torch.zeros_like(query_layer[:, :, :1]).expand(
            -1, -1, key_layer.size(-2)
        )

        # Shift relative positions into table-index space and clamp to the
        # valid range, then expand across the flattened batch*heads dim.
        # NOTE(review): this same index is reused below for the p2c gather;
        # DeBERTa derives the p2c index from -relative_pos instead — confirm
        # the sign convention is intended here.
        c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
        c2p_pos = c2p_pos.squeeze(0).expand(
            query_layer.size(0), query_layer.size(1), relative_pos.size(-1)
        )

        # Content-to-position: query content attends to relative-position keys.
        if "c2p" in self.pos_att_type:
            pos_key_layer = (
                self.pos_key_proj(rel_embeddings)
                if not self.share_att_key
                else self.key_proj(rel_embeddings)
            )
            pos_key_layer = self.transpose_for_scores(pos_key_layer).repeat(
                batch_size, 1, 1
            )

            c2p_scale = 1.0 / math.sqrt(self.head_size * scale_factor)
            c2p_att = torch.bmm(
                query_layer, pos_key_layer.transpose(-1, -2) * c2p_scale
            )
            # Select, for each (query, key) pair, the score against its own
            # relative-position embedding.
            c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos)
            score = score + c2p_att

        # Position-to-content: relative-position queries attend to key content.
        if "p2c" in self.pos_att_type:
            pos_query_layer = (
                self.pos_query_proj(rel_embeddings)
                if not self.share_att_key
                else self.query_proj(rel_embeddings)
            )
            pos_query_layer = self.transpose_for_scores(pos_query_layer).repeat(
                batch_size, 1, 1
            )

            p2c_scale = 1.0 / math.sqrt(self.head_size * scale_factor)
            p2c_att = torch.bmm(
                pos_query_layer * p2c_scale, key_layer.transpose(-1, -2)
            )
            # Gather along the position axis of (B*H, 2*att_span, k_len),
            # producing (B*H, q_len, k_len); see the sign-convention note
            # above on the reused index.
            p2c_att = torch.gather(p2c_att, dim=-2, index=c2p_pos)
            score = score + p2c_att

        return score
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HELMBertEmbeddings(nn.Module):
    """Token and position embeddings for HELM-BERT."""

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed token IDs and build absolute position embeddings.

        The position embeddings are returned separately instead of being added
        to the token embeddings: the encoder relies on relative-position
        attention, and the absolute positions are consumed only by the
        Enhanced Mask Decoder.

        Returns:
            Tuple of (token_embeddings, position_embeddings).
        """
        batch_size, seq_len = input_ids.shape

        token_embeds = self.word_embeddings(input_ids)

        # One row of position IDs, broadcast across the batch.
        pos_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        pos_embeds = self.position_embeddings(
            pos_ids.unsqueeze(0).expand(batch_size, -1)
        )

        # Masked LayerNorm keeps padding rows at zero; dropout comes last.
        token_embeds = masked_layer_norm(self.layer_norm, token_embeds, attention_mask)
        token_embeds = self.dropout(token_embeds)

        return token_embeds, pos_embeds
|
|
|
|
|
|
|
|
class NgieLayer(nn.Module):
    """n-gram Induced Input Encoding (nGiE) layer.

    Captures local n-gram patterns using 1D convolution.
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__()

        # Same-length convolution over the sequence axis.
        self.conv = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=config.ngie_kernel_size,
            padding=(config.ngie_kernel_size - 1) // 2,
            groups=1,
        )
        self.activation = nn.Tanh()
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.ngie_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Mix convolutional n-gram features into the residual stream.

        Args:
            hidden_states: Input to the convolution (batch, seq, hidden).
            residual_states: Tensor added back after the conv (batch, seq, hidden).
            attention_mask: 1 = valid token, 0 = padding.

        Returns:
            masked-LayerNorm(residual_states + conv features).
        """
        # Conv1d wants channels-first: (B, S, H) -> (B, H, S) -> conv -> back.
        conv_out = self.conv(hidden_states.permute(0, 2, 1).contiguous())
        conv_out = conv_out.permute(0, 2, 1).contiguous()

        # bool masks for masked_fill_ need torch >= 1.2; fall back to byte.
        if version.Version(torch.__version__) >= version.Version("1.2.0a"):
            pad_positions = (1 - attention_mask).bool()
        else:
            pad_positions = (1 - attention_mask).byte()

        # Zero the conv output at padding positions before mixing it in.
        conv_out.masked_fill_(pad_positions.unsqueeze(-1).expand(conv_out.size()), 0)

        # Dropout is applied *before* tanh, preserving the original ordering.
        conv_out = self.activation(self.dropout(conv_out))

        return masked_layer_norm(
            self.layer_norm, residual_states + conv_out, attention_mask
        )
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Transformer block with disentangled attention and GELU FFN.

    Uses post-norm residual layout; inputs and outputs are in
    [seq_len, batch, hidden] layout while the attention module works in
    [batch, seq_len, hidden], hence the transposes throughout forward().
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__()

        self.self_attn = DisentangledSelfAttention(config)
        self.attn_output_dense = nn.Linear(config.hidden_size, config.hidden_size)

        # Feed-forward: Linear+GELU then projection back to hidden size.
        self.linear1 = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size), nn.GELU()
        )
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)

        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        src: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        query_states: Optional[torch.Tensor] = None,
        relative_pos: Optional[torch.Tensor] = None,
        rel_embeddings: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass.

        Args:
            src: Input embeddings [seq_len, batch, hidden]
            src_key_padding_mask: Padding mask [batch, seq_len]
            output_attentions: Whether to return attention weights
            query_states: Optional query for EMD
            relative_pos: Relative position indices
            rel_embeddings: Relative position embeddings

        Returns:
            Tuple of (output, optional attention weights)
        """
        # Attention works batch-first: [seq, batch, H] -> [batch, seq, H].
        src_transposed = src.transpose(0, 1)

        # Padding mask (True = pad) is inverted into a keep-mask (1 = valid)
        # for the attention module.
        attention_mask = None
        if src_key_padding_mask is not None:
            attention_mask = (~src_key_padding_mask).float()

        query_states_transposed = None
        if query_states is not None:
            query_states_transposed = query_states.transpose(0, 1)

        attn_result = self.self_attn(
            src_transposed,
            attention_mask,
            output_attentions=output_attentions,
            query_states=query_states_transposed,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
        )
        attn_output = attn_result["hidden_states"].transpose(0, 1)
        attn_weights = attn_result.get("attention_probs") if output_attentions else None

        attn_output = self.attn_output_dense(attn_output)

        # EMD case: when a separate query was supplied, the residual wraps
        # around the *query* stream, not the key/value stream.
        residual_input = query_states if query_states is not None else src
        src = residual_input + self.dropout1(attn_output)

        # masked_layer_norm is called without a mask here, so it reduces to a
        # plain LayerNorm over the last dim; the transposes are kept for
        # symmetry with the masked usage elsewhere in the file.
        src = src.transpose(0, 1)
        src = masked_layer_norm(self.norm1, src)
        src = src.transpose(0, 1)

        # Feed-forward sub-layer with its own residual + post-norm.
        ff_output = self.linear1(src)
        ff_output = self.linear2(ff_output)
        ff_output = self.dropout2(ff_output)
        src = src + ff_output

        src = src.transpose(0, 1)
        src = masked_layer_norm(self.norm2, src)
        src = src.transpose(0, 1)

        return src, attn_weights
|
|
|
|
|
|
|
|
class HELMBertEncoder(nn.Module):
    """Stack of transformer blocks with nGiE layer."""

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.config = config

        # Single nGiE conv layer, applied after the first transformer block.
        self.ngie_layer = NgieLayer(config)

        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )

    def get_rel_embedding(self) -> Optional[torch.Tensor]:
        """Get relative position embeddings from first layer.

        Used by the EMD so the decoding passes share the same table as the
        encoder; returns None when no layer exposes one.
        """
        if len(self.layers) > 0:
            first_layer = self.layers[0]
            if hasattr(first_layer, "self_attn") and hasattr(
                first_layer.self_attn, "rel_embeddings"
            ):
                return first_layer.self_attn.rel_embeddings.weight
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        use_emd: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple], Optional[Tuple]]:
        """Forward pass.

        Args:
            hidden_states: Input embeddings [batch, seq, hidden]
            attention_mask: Attention mask [batch, seq]
            position_embeddings: Position embeddings for EMD
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            use_emd: Whether to use Enhanced Mask Decoder

        Returns:
            Tuple of (last_hidden_state, emd_output, all_hidden_states, all_attentions)
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Keep the raw (batch-first) embeddings: they feed the nGiE conv.
        ngie_input_states = hidden_states

        # TransformerBlock expects [seq, batch, hidden].
        hidden_states = hidden_states.transpose(0, 1)

        # Blocks take a padding mask with True = pad.
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = ~attention_mask.bool()

        # Output of layer n-2, captured for the EMD queries.
        layer_minus_2 = None
        num_layers = len(self.layers)

        for layer_idx, layer in enumerate(self.layers):
            if output_hidden_states:
                # Recorded batch-first, consistent with the final output.
                all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)

            hidden_states, attn_weights = layer(
                hidden_states,
                src_key_padding_mask=key_padding_mask,
                output_attentions=output_attentions,
            )

            if output_attentions and attn_weights is not None:
                all_attentions = all_attentions + (attn_weights,)

            # nGiE is injected once, right after the first block, mixing the
            # raw embeddings' conv features into the block-0 output.
            # NOTE(review): NgieLayer assumes attention_mask is not None; all
            # in-file callers pass an explicit mask — confirm for external use.
            if layer_idx == 0:
                hidden_states_batch = hidden_states.transpose(0, 1)
                hidden_states_batch = self.ngie_layer(
                    ngie_input_states, hidden_states_batch, attention_mask
                )
                hidden_states = hidden_states_batch.transpose(0, 1)

            # Snapshot the states entering the last layer for the EMD.
            if use_emd and layer_idx == num_layers - 2:
                layer_minus_2 = hidden_states

        hidden_states = hidden_states.transpose(0, 1)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Enhanced Mask Decoder: rerun the (shared) last layer twice with
        # queries that carry absolute position information, keys/values fixed
        # to the layer n-2 states.
        emd_output = None
        if use_emd and layer_minus_2 is not None and position_embeddings is not None:
            emd_keys_values = layer_minus_2
            # Queries = layer n-2 states + absolute position embeddings
            # (added batch-first, then back to [seq, batch, hidden]).
            emd_query = layer_minus_2.transpose(0, 1)
            emd_query = position_embeddings + emd_query
            emd_query = emd_query.transpose(0, 1)

            rel_embeddings = self.get_rel_embedding()
            last_layer = self.layers[-1]

            # Two decoding passes, DeBERTa-style, reusing the last block.
            for _ in range(2):
                emd_query, _ = last_layer(
                    emd_keys_values,
                    src_key_padding_mask=key_padding_mask,
                    query_states=emd_query,
                    relative_pos=None,
                    rel_embeddings=rel_embeddings,
                )

            emd_output = emd_query.transpose(0, 1)

        return hidden_states, emd_output, all_hidden_states, all_attentions
|
|
|
|
|
|
|
|
class HELMBertPooler(nn.Module):
    """Mean pooling over sequence."""

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.hidden_size = config.hidden_size

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Mean-pool the sequence dimension, ignoring padded positions.

        Args:
            hidden_states: [batch, seq, hidden]
            attention_mask: [batch, seq], 1 = valid token

        Returns:
            Pooled output [batch, hidden]
        """
        if attention_mask is None:
            # No mask: plain average over the whole sequence.
            return hidden_states.mean(dim=1)

        weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        token_sum = torch.sum(hidden_states * weights, 1)
        # Clamp the denominator so a fully-padded row cannot divide by zero.
        token_count = torch.clamp(
            weights.sum(1), min=torch.finfo(hidden_states.dtype).eps
        )
        return token_sum / token_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HELMBertPreTrainedModel(PreTrainedModel):
    """Base class wiring HELM-BERT models into the HF config/init machinery."""

    config_class = HELMBertConfig
    base_model_prefix = "helmbert"

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights with BERT-style initialization.

        Linear/Embedding weights get N(0, 0.02); biases and the padding
        embedding row are zeroed; LayerNorm starts as the identity.
        """
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HELMBertModel(HELMBertPreTrainedModel):
    """HELM-BERT base model.

    Outputs the last hidden states plus a mean-pooled representation.

    Example:
        >>> from helmbert import HELMBertModel, HELMBertTokenizer
        >>> tokenizer = HELMBertTokenizer()
        >>> model = HELMBertModel.from_pretrained("./checkpoints/helmbert-base")
        >>> inputs = tokenizer("PEPTIDE1{A.C.D.E}$$$$", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooler_output = outputs.pooler_output
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = HELMBertEmbeddings(config)
        self.encoder = HELMBertEncoder(config)
        self.pooler = HELMBertPooler(config)

        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode a batch of token IDs.

        Args:
            input_ids: Token IDs [batch, seq]
            attention_mask: Attention mask [batch, seq]; defaults to all-ones.
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            return_dict: Whether to return a ModelOutput

        Returns:
            BaseModelOutputWithPooling or tuple
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Token embeddings feed the encoder; absolute position embeddings are
        # only consumed by the EMD, which the base model does not run.
        token_embeds, pos_embeds = self.embeddings(input_ids, attention_mask)

        last_hidden_state, _, all_hidden, all_attn = self.encoder(
            token_embeds,
            attention_mask=attention_mask,
            position_embeddings=pos_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_emd=False,
        )

        # Masked mean pooling over valid tokens.
        pooled = self.pooler(last_hidden_state, attention_mask)

        if not return_dict:
            return (last_hidden_state, pooled, all_hidden, all_attn)

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled,
            hidden_states=all_hidden,
            attentions=all_attn,
        )
|
|
|
|
|
|
|
|
class HELMBertLMHead(nn.Module):
    """MLM head with weight tying (HuggingFace standard)."""

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.activation = nn.GELU()

        # The decoder weight is tied to the input embeddings by the parent
        # model (see HELMBertForMaskedLM._tied_weights_keys); its bias stays
        # an independent parameter.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits.

        Args:
            hidden_states: [batch, seq, hidden]

        Returns:
            Logits [batch, seq, vocab]
        """
        # dense -> GELU -> LayerNorm -> vocab projection.
        transformed = self.layer_norm(self.activation(self.dense(hidden_states)))
        return self.decoder(transformed)
|
|
|
|
|
|
|
|
class HELMBertForMaskedLM(HELMBertPreTrainedModel):
    """HELM-BERT for Masked Language Modeling with Enhanced Mask Decoder (EMD).

    Example:
        >>> from helmbert import HELMBertForMaskedLM, HELMBertTokenizer
        >>> tokenizer = HELMBertTokenizer()
        >>> model = HELMBertForMaskedLM.from_pretrained("./checkpoints/helmbert-base")
        >>> inputs = tokenizer("PEPTIDE1{A.¶.D.E}$$$$", return_tensors="pt")  # ¶ is mask
        >>> outputs = model(**inputs)
        >>> predictions = outputs.logits.argmax(dim=-1)
    """

    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config: HELMBertConfig):
        super().__init__(config)
        self.helmbert = HELMBertModel(config)
        self.lm_head = HELMBertLMHead(config)

        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.lm_head.decoder = new_embeddings

    def get_input_embeddings(self) -> nn.Embedding:
        return self.helmbert.embeddings.word_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        use_emd: bool = True,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Predict masked tokens, optionally through the Enhanced Mask Decoder.

        Args:
            input_ids: Token IDs [batch, seq]
            attention_mask: Attention mask [batch, seq]; defaults to all-ones.
            labels: MLM labels (-100 marks positions excluded from the loss)
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            return_dict: Whether to return a ModelOutput
            use_emd: Whether to use Enhanced Mask Decoder

        Returns:
            MaskedLMOutput or tuple
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Embeddings and encoder are called directly (rather than through
        # self.helmbert.forward) so the EMD path can be enabled.
        token_embeds, pos_embeds = self.helmbert.embeddings(
            input_ids, attention_mask
        )
        last_hidden, emd_hidden, all_hidden, all_attn = self.helmbert.encoder(
            token_embeds,
            attention_mask=attention_mask,
            position_embeddings=pos_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_emd=use_emd,
        )

        # Prefer the EMD output for prediction when it was produced.
        if use_emd and emd_hidden is not None:
            sequence_output = emd_hidden
        else:
            sequence_output = last_hidden

        prediction_scores = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss(ignore_index=-100)(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores, all_hidden, all_attn)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=all_hidden,
            attentions=all_attn,
        )
|
|
|
|
|
|
|
|
class MLPHead(nn.Module):
    """MLP head with skip connections for classification/regression.

    Architecture: input -> [Linear -> GELU -> LayerNorm -> Dropout (+ skip)] x N -> Linear -> output
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dims: list,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()

        # One (Linear, LayerNorm, Dropout) triple per requested hidden width.
        width = input_dim
        for next_width in hidden_dims:
            self.layers.append(nn.Linear(width, next_width))
            self.norms.append(nn.LayerNorm(next_width))
            self.dropouts.append(nn.Dropout(dropout))
            width = next_width

        self.output_layer = nn.Linear(width, output_dim)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP; a hidden layer gets a residual add when it keeps width."""
        for linear, norm, drop in zip(self.layers, self.norms, self.dropouts):
            h = linear(x)
            # Residual connection is only possible when shapes line up.
            if h.shape == x.shape:
                h = h + x
            x = drop(norm(self.activation(h)))
        return self.output_layer(x)
|
|
|
|
|
|
|
|
class HELMBertForSequenceClassification(HELMBertPreTrainedModel):
    """HELM-BERT for sequence classification/regression.

    Example:
        >>> from helmbert import HELMBertForSequenceClassification, HELMBertConfig
        >>> # Simple linear head (default)
        >>> config = HELMBertConfig(num_labels=1)
        >>> model = HELMBertForSequenceClassification(config)
        >>>
        >>> # MLP head with 2 layers (for permeability prediction)
        >>> config = HELMBertConfig(num_labels=1, classifier_num_layers=2)
        >>> model = HELMBertForSequenceClassification(config)
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.helmbert = HELMBertModel(config)

        # classifier_num_layers > 0 selects the residual MLP head; otherwise a
        # plain dropout + linear probe is used.
        if config.classifier_num_layers > 0:
            self.classifier = MLPHead(
                input_dim=config.hidden_size,
                output_dim=config.num_labels,
                hidden_dims=[config.hidden_size] * config.classifier_num_layers,
                dropout=config.classifier_dropout,
            )
        else:
            self.dropout = nn.Dropout(config.classifier_dropout)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Classify or regress from the pooled sequence representation.

        Args:
            input_ids: Token IDs [batch, seq]
            attention_mask: Attention mask [batch, seq]
            labels: Targets; dtype and num_labels determine the loss used
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            return_dict: Whether to return a ModelOutput

        Returns:
            SequenceClassifierOutput or tuple
        """
        outputs = self.helmbert(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        features = outputs.pooler_output

        # Only the linear-head configuration registers a standalone dropout.
        if hasattr(self, "dropout"):
            features = self.dropout(features)
        logits = self.classifier(features)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _compute_loss(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> Optional[torch.Tensor]:
        """Pick the loss by problem type, inferring and caching it on config."""
        if self.config.problem_type is None:
            # HF-standard inference: single output -> regression; integer
            # labels with multiple outputs -> single-label classification.
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and (
                labels.dtype == torch.long or labels.dtype == torch.int
            ):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        loss = None
        if self.config.problem_type == "regression":
            mse = nn.MSELoss()
            if self.num_labels == 1:
                loss = mse(logits.squeeze(), labels.squeeze())
            else:
                loss = mse(logits, labels)
        elif self.config.problem_type == "single_label_classification":
            ce = nn.CrossEntropyLoss()
            loss = ce(logits.view(-1, self.num_labels), labels.view(-1))
        elif self.config.problem_type == "multi_label_classification":
            loss = nn.BCEWithLogitsLoss()(logits, labels)
        return loss
|
|
|