"""HELM-BERT model implementation. This module implements the HELM-BERT model with: - Disentangled attention (DeBERTa-style) - Enhanced Mask Decoder (EMD) for MLM - n-gram Induced Encoding (nGiE) layer """ import math from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn from packaging import version from torch import _softmax_backward_data from transformers import PreTrainedModel from transformers.modeling_outputs import ( BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput, ) from .configuration_helmbert import HELMBertConfig # ----------------------------------------------------------------------------- # Utility Functions # ----------------------------------------------------------------------------- def masked_layer_norm( layer_norm: nn.LayerNorm, x: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: """Apply LayerNorm with masking to avoid updates on padding tokens. Args: layer_norm: LayerNorm module x: Input tensor (batch_size, seq_len, hidden_size) mask: Mask tensor where 0 = padding (ignored), 1 = valid token Returns: Normalized tensor with padding positions zeroed out """ output = layer_norm(x).to(x.dtype) if mask is None: return output if mask.dim() != x.dim(): if mask.dim() == 4: mask = mask.squeeze(1).squeeze(1) mask = mask.unsqueeze(2) mask = mask.to(output.dtype) return output * mask class XSoftmax(torch.autograd.Function): """Masked Softmax optimized for memory efficiency.""" @staticmethod def forward( ctx, input: torch.Tensor, mask: Optional[torch.Tensor], dim: int ) -> torch.Tensor: ctx.dim = dim if mask is not None: rmask = ~(mask.bool()) if rmask.dim() == 2: rmask = rmask.unsqueeze(1).unsqueeze(2) elif rmask.dim() == 3: rmask = rmask.unsqueeze(2) output = input.masked_fill(rmask, float("-inf")) else: output = input output = torch.softmax(output, ctx.dim) if mask is not None: output.masked_fill_(rmask, 0) ctx.save_for_backward(output) return output @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]: (output,) = ctx.saved_tensors if version.Version(torch.__version__) >= version.Version("1.11.0"): input_grad = _softmax_backward_data( grad_output, output, ctx.dim, output.dtype ) else: input_grad = _softmax_backward_data(grad_output, output, ctx.dim, output) return input_grad, None, None def build_relative_position( query_size: int, key_size: int, bucket_size: int = -1, max_position: int = 512, device: Optional[torch.device] = None, ) -> torch.Tensor: """Build relative position matrix with optional log-bucketing.""" q_ids = torch.arange(query_size, dtype=torch.long, device=device) k_ids = torch.arange(key_size, dtype=torch.long, device=device) rel_pos = q_ids.unsqueeze(1) - k_ids.unsqueeze(0) if bucket_size > 0: rel_buckets = 0 num_buckets = bucket_size rel_buckets += (rel_pos > 0).long() * (num_buckets // 2) rel_pos = torch.abs(rel_pos) max_exact = num_buckets // 4 is_small = rel_pos < max_exact rel_pos_if_large = ( max_exact + ( torch.log(rel_pos.float() / max_exact) / math.log(max_position / max_exact) * (num_buckets // 4 - 1) ).long() ) rel_pos_if_large = torch.min( rel_pos_if_large, torch.full_like(rel_pos_if_large, num_buckets // 2 - 1) ) rel_buckets += torch.where(is_small, rel_pos, rel_pos_if_large) return rel_buckets else: rel_pos = torch.clamp(rel_pos, -max_position, max_position) return rel_pos + max_position # ----------------------------------------------------------------------------- # Attention Modules # 
# -----------------------------------------------------------------------------
# Attention Modules
# -----------------------------------------------------------------------------


class DisentangledSelfAttention(nn.Module):
    """Disentangled self-attention with content and position separation.

    Implements content-to-content, content-to-position, and position-to-content
    attention as described in DeBERTa.
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.num_heads = config.num_attention_heads
        self.head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_heads * self.head_size

        # Content projections
        self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

        # Position attention configuration
        self.pos_att_type = [x.strip() for x in config.pos_att_type.lower().split("|")]
        self.max_relative_positions = config.max_relative_positions
        self.position_buckets = config.position_buckets
        self.share_att_key = config.share_att_key

        # Position embedding size
        self.pos_ebd_size = config.max_relative_positions
        if config.position_buckets > 0:
            self.pos_ebd_size = config.position_buckets

        # Position embeddings
        self.rel_embeddings = nn.Embedding(self.pos_ebd_size * 2, config.hidden_size)

        # Position projections (only needed when not sharing the content keys/queries)
        if not self.share_att_key:
            if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
                self.pos_key_proj = nn.Linear(
                    config.hidden_size, self.all_head_size, bias=True
                )
            if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
                self.pos_query_proj = nn.Linear(
                    config.hidden_size, self.all_head_size, bias=False
                )

        # Dropout
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.pos_dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, hidden) into (batch * heads, seq, head_size)."""
        new_shape = x.size()[:-1] + (self.num_heads, self.head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        query_states: Optional[torch.Tensor] = None,
        relative_pos: Optional[torch.Tensor] = None,
        rel_embeddings: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Forward pass of disentangled attention."""
        if query_states is None:
            query_states = hidden_states

        # Compute Q, K, V
        query_layer = self.transpose_for_scores(self.query_proj(query_states)).float()
        key_layer = self.transpose_for_scores(self.key_proj(hidden_states)).float()
        value_layer = self.transpose_for_scores(self.value_proj(hidden_states))

        # Calculate scale factor: one unit per enabled attention component
        scale_factor = 1
        if "c2p" in self.pos_att_type:
            scale_factor += 1
        if "p2c" in self.pos_att_type:
            scale_factor += 1
        if "p2p" in self.pos_att_type:
            scale_factor += 1
        scale = 1.0 / math.sqrt(self.head_size * scale_factor)

        # Content-to-content attention (c2c)
        c2c_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) * scale)
        attention_scores = c2c_scores

        # Add relative position bias if any position component is enabled
        if any(self.pos_att_type):
            rel_att = self._disentangled_attention_bias(
                query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
            )
            if rel_att is not None:
                attention_scores = attention_scores + rel_att

        # Normalize scores for numerical stability
        attention_scores = (
            attention_scores - attention_scores.max(dim=-1, keepdim=True)[0].detach()
        )
        attention_scores = attention_scores.to(hidden_states.dtype)

        # Reshape for XSoftmax
        attention_scores = attention_scores.view(
            -1, self.num_heads, attention_scores.size(-2), attention_scores.size(-1)
        )

        # Apply XSoftmax
        attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
        attention_probs = self.dropout(attention_probs)

        # Apply attention to values
        attention_probs_flat = attention_probs.view(
            -1, attention_probs.size(-2), attention_probs.size(-1)
        )
        context_layer = torch.bmm(attention_probs_flat, value_layer)

        # Reshape output back to (batch, seq, hidden)
        context_layer = context_layer.view(
            -1, self.num_heads, context_layer.size(-2), context_layer.size(-1)
        )
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        return {"hidden_states": context_layer, "attention_probs": attention_probs}

    def _disentangled_attention_bias(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        relative_pos: Optional[torch.Tensor],
        rel_embeddings: Optional[torch.Tensor],
        scale_factor: int,
    ) -> Optional[torch.Tensor]:
        """Compute disentangled attention bias."""
        if relative_pos is None:
            q_size = query_layer.size(-2)
            k_size = key_layer.size(-2)
            relative_pos = build_relative_position(
                q_size,
                k_size,
                bucket_size=self.position_buckets,
                max_position=self.max_relative_positions,
                device=query_layer.device,
            )
        if relative_pos.dim() == 2:
            relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
        elif relative_pos.dim() == 3:
            relative_pos = relative_pos.unsqueeze(1)

        batch_size = query_layer.size(0) // self.num_heads

        # Get position embeddings
        if rel_embeddings is None:
            rel_embeddings = self.rel_embeddings.weight
        att_span = self.pos_ebd_size
        rel_embeddings = rel_embeddings[
            self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :
        ].unsqueeze(0)
        rel_embeddings = self.pos_dropout(rel_embeddings)

        score = torch.zeros_like(query_layer[:, :, :1]).expand(
            -1, -1, key_layer.size(-2)
        )

        # Prepare position indices
        c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
        c2p_pos = c2p_pos.squeeze(0).expand(
            query_layer.size(0), query_layer.size(1), relative_pos.size(-1)
        )

        # Content-to-position (c2p)
        if "c2p" in self.pos_att_type:
            pos_key_layer = (
                self.pos_key_proj(rel_embeddings)
                if not self.share_att_key
                else self.key_proj(rel_embeddings)
            )
            pos_key_layer = self.transpose_for_scores(pos_key_layer).repeat(
                batch_size, 1, 1
            )
            c2p_scale = 1.0 / math.sqrt(self.head_size * scale_factor)
            c2p_att = torch.bmm(
                query_layer, pos_key_layer.transpose(-1, -2) * c2p_scale
            )
            c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos)
            score = score + c2p_att

        # Position-to-content (p2c)
        if "p2c" in self.pos_att_type:
            pos_query_layer = (
                self.pos_query_proj(rel_embeddings)
                if not self.share_att_key
                else self.query_proj(rel_embeddings)
            )
            pos_query_layer = self.transpose_for_scores(pos_query_layer).repeat(
                batch_size, 1, 1
            )
            p2c_scale = 1.0 / math.sqrt(self.head_size * scale_factor)
            p2c_att = torch.bmm(
                pos_query_layer * p2c_scale, key_layer.transpose(-1, -2)
            )
            p2c_att = torch.gather(p2c_att, dim=-2, index=c2p_pos)
            score = score + p2c_att

        return score
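# Illustrative sketch (hypothetical helper, not used by the model): the gather
# trick behind the c2p bias above. Scores are first computed against every
# relative-position embedding, then `torch.gather` selects, for each (i, j)
# pair, the score at offset i - j.
def _demo_c2p_gather() -> torch.Tensor:
    span, q_len, k_len = 4, 3, 3
    # (batch * heads, q_len, 2 * span): one score per candidate offset
    scores_all_offsets = torch.randn(1, q_len, 2 * span)
    rel = build_relative_position(q_len, k_len, max_position=span)  # (q, k)
    index = rel.unsqueeze(0).expand(1, q_len, k_len)  # values in [0, 2 * span)
    # (batch * heads, q_len, k_len): the per-pair c2p scores
    return torch.gather(scores_all_offsets, dim=-1, index=index)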
# -----------------------------------------------------------------------------
# Transformer Components
# -----------------------------------------------------------------------------


class HELMBertEmbeddings(nn.Module):
    """Token and position embeddings for HELM-BERT."""

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass.

        Note that absolute position embeddings are *not* added to the token
        embeddings here; they are returned separately for the Enhanced Mask
        Decoder, while the encoder itself relies on relative positions.

        Returns:
            Tuple of (token_embeddings, position_embeddings)
        """
        batch_size, seq_len = input_ids.shape

        # Token embeddings
        embeddings = self.word_embeddings(input_ids)

        # Position embeddings
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
        position_embeds = self.position_embeddings(position_ids)

        # Layer norm and dropout
        embeddings = masked_layer_norm(self.layer_norm, embeddings, attention_mask)
        embeddings = self.dropout(embeddings)

        return embeddings, position_embeds
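# Illustrative sketch (hypothetical helper): `masked_layer_norm`, used above,
# zeroes the normalized activations at padding positions so they cannot leak
# into later residual streams.
def _demo_masked_layer_norm() -> torch.Tensor:
    ln = nn.LayerNorm(8)
    x = torch.randn(2, 5, 8)
    mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
    out = masked_layer_norm(ln, x, mask)
    assert torch.all(out[0, 3:] == 0)  # padding rows are exactly zero
    return out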
class NgieLayer(nn.Module):
    """n-gram Induced Input Encoding (nGiE) layer.

    Captures local n-gram patterns using 1D convolution.
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        # Padding of (kernel_size - 1) // 2 preserves sequence length only for
        # odd kernel sizes, which the residual connection below requires.
        self.conv = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=config.ngie_kernel_size,
            padding=(config.ngie_kernel_size - 1) // 2,
            groups=1,
        )
        self.activation = nn.Tanh()
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.ngie_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            hidden_states: Input to convolution (batch, seq, hidden)
            residual_states: States for residual connection (batch, seq, hidden)
            attention_mask: Mask where 1 = valid, 0 = padding

        Returns:
            Output with n-gram information incorporated
        """
        # Apply 1D convolution over the sequence dimension
        out = (
            self.conv(hidden_states.permute(0, 2, 1).contiguous())
            .permute(0, 2, 1)
            .contiguous()
        )

        # Create reverse mask for padding
        if version.Version(torch.__version__) >= version.Version("1.2.0a"):
            rmask = (1 - attention_mask).bool()
        else:
            rmask = (1 - attention_mask).byte()

        # Zero out padding positions
        out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)

        # Apply activation and dropout
        out = self.activation(self.dropout(out))

        # Residual connection with LayerNorm
        output_states = masked_layer_norm(
            self.layer_norm, residual_states + out, attention_mask
        )
        return output_states


class TransformerBlock(nn.Module):
    """Transformer block with disentangled attention and GELU FFN."""

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.self_attn = DisentangledSelfAttention(config)
        self.attn_output_dense = nn.Linear(config.hidden_size, config.hidden_size)

        # FFN with GELU
        self.linear1 = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size), nn.GELU()
        )
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)

        # Normalization and dropout
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        src: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        query_states: Optional[torch.Tensor] = None,
        relative_pos: Optional[torch.Tensor] = None,
        rel_embeddings: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass.

        Args:
            src: Input embeddings [seq_len, batch, hidden]
            src_key_padding_mask: Padding mask [batch, seq_len]
            output_attentions: Whether to return attention weights
            query_states: Optional query for EMD
            relative_pos: Relative position indices
            rel_embeddings: Relative position embeddings

        Returns:
            Tuple of (output, optional attention weights)
        """
        # Transpose for attention: [seq, batch, hidden] -> [batch, seq, hidden]
        src_transposed = src.transpose(0, 1)

        # Convert padding mask to attention mask (1 = valid, 0 = padding)
        attention_mask = None
        if src_key_padding_mask is not None:
            attention_mask = (~src_key_padding_mask).float()

        query_states_transposed = None
        if query_states is not None:
            query_states_transposed = query_states.transpose(0, 1)

        # Self-attention
        attn_result = self.self_attn(
            src_transposed,
            attention_mask,
            output_attentions=output_attentions,
            query_states=query_states_transposed,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
        )
        attn_output = attn_result["hidden_states"].transpose(0, 1)
        attn_weights = attn_result.get("attention_probs") if output_attentions else None

        # Dense projection
        attn_output = self.attn_output_dense(attn_output)

        # Residual connection (the EMD query, if given, is the residual stream)
        residual_input = query_states if query_states is not None else src
        src = residual_input + self.dropout1(attn_output)

        # LayerNorm (expects [batch, seq, hidden])
        src = src.transpose(0, 1)
        src = masked_layer_norm(self.norm1, src)
        src = src.transpose(0, 1)

        # FFN
        ff_output = self.linear1(src)
        ff_output = self.linear2(ff_output)
        ff_output = self.dropout2(ff_output)
        src = src + ff_output

        # LayerNorm
        src = src.transpose(0, 1)
        src = masked_layer_norm(self.norm2, src)
        src = src.transpose(0, 1)

        return src, attn_weights
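# Illustrative sketch (hypothetical helper): why NgieLayer's padding choice
# matters. With padding = (k - 1) // 2, a Conv1d preserves sequence length for
# odd kernel sizes; an even `ngie_kernel_size` would shorten the sequence by
# one and break the residual addition.
def _demo_ngie_padding() -> None:
    x = torch.randn(2, 16, 10)  # (batch, channels, seq)
    odd = nn.Conv1d(16, 16, kernel_size=3, padding=(3 - 1) // 2)
    assert odd(x).shape[-1] == 10  # length preserved
    even = nn.Conv1d(16, 16, kernel_size=4, padding=(4 - 1) // 2)
    assert even(x).shape[-1] == 9  # length shrinks -> residual would fail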
class HELMBertEncoder(nn.Module):
    """Stack of transformer blocks with nGiE layer."""

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.config = config

        # nGiE layer (applied after the first transformer block)
        self.ngie_layer = NgieLayer(config)

        # Transformer blocks
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )

    def get_rel_embedding(self) -> Optional[torch.Tensor]:
        """Get relative position embeddings from the first layer."""
        if len(self.layers) > 0:
            first_layer = self.layers[0]
            if hasattr(first_layer, "self_attn") and hasattr(
                first_layer.self_attn, "rel_embeddings"
            ):
                return first_layer.self_attn.rel_embeddings.weight
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        use_emd: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple], Optional[Tuple]]:
        """Forward pass.

        Args:
            hidden_states: Input embeddings [batch, seq, hidden]
            attention_mask: Attention mask [batch, seq]
            position_embeddings: Position embeddings for EMD
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            use_emd: Whether to use the Enhanced Mask Decoder

        Returns:
            Tuple of (last_hidden_state, emd_output, all_hidden_states,
            all_attentions)
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Store for nGiE
        ngie_input_states = hidden_states

        # [batch, seq, hidden] -> [seq, batch, hidden]
        hidden_states = hidden_states.transpose(0, 1)

        # Key padding mask (True = padding)
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = ~attention_mask.bool()

        # Store the output of layer n-2 for EMD
        layer_minus_2 = None
        num_layers = len(self.layers)

        for layer_idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)

            hidden_states, attn_weights = layer(
                hidden_states,
                src_key_padding_mask=key_padding_mask,
                output_attentions=output_attentions,
            )

            if output_attentions and attn_weights is not None:
                all_attentions = all_attentions + (attn_weights,)

            # Apply nGiE after the first layer
            if layer_idx == 0:
                hidden_states_batch = hidden_states.transpose(0, 1)
                hidden_states_batch = self.ngie_layer(
                    ngie_input_states, hidden_states_batch, attention_mask
                )
                hidden_states = hidden_states_batch.transpose(0, 1)

            # Store the output of layer n-2 for EMD
            if use_emd and layer_idx == num_layers - 2:
                layer_minus_2 = hidden_states

        # Convert back to [batch, seq, hidden]
        hidden_states = hidden_states.transpose(0, 1)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Enhanced Mask Decoder (EMD) for MLM: re-run the last layer twice,
        # seeding the query with absolute position embeddings while keys and
        # values come from the penultimate layer's output.
        emd_output = None
        if use_emd and layer_minus_2 is not None and position_embeddings is not None:
            emd_keys_values = layer_minus_2
            emd_query = layer_minus_2.transpose(0, 1)
            emd_query = position_embeddings + emd_query
            emd_query = emd_query.transpose(0, 1)

            rel_embeddings = self.get_rel_embedding()
            last_layer = self.layers[-1]
            for _ in range(2):
                emd_query, _ = last_layer(
                    emd_keys_values,
                    src_key_padding_mask=key_padding_mask,
                    query_states=emd_query,
                    relative_pos=None,
                    rel_embeddings=rel_embeddings,
                )
            emd_output = emd_query.transpose(0, 1)

        return hidden_states, emd_output, all_hidden_states, all_attentions
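# Illustrative sketch (hypothetical helper): the EMD query seeding performed
# above. The encoder body never adds absolute positions to its inputs, so the
# decoder injects them here, letting the MLM head see where a masked token
# sits in the sequence.
def _demo_emd_query_seeding() -> torch.Tensor:
    seq_len, batch, hidden = 5, 2, 8
    layer_minus_2 = torch.randn(seq_len, batch, hidden)        # [seq, batch, hidden]
    position_embeddings = torch.randn(batch, seq_len, hidden)  # [batch, seq, hidden]
    query = position_embeddings + layer_minus_2.transpose(0, 1)
    return query.transpose(0, 1)  # back to [seq, batch, hidden]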
class HELMBertPooler(nn.Module):
    """Mean pooling over the sequence."""

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.hidden_size = config.hidden_size

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply mean pooling.

        Args:
            hidden_states: [batch, seq, hidden]
            attention_mask: [batch, seq]

        Returns:
            Pooled output [batch, hidden]
        """
        if attention_mask is not None:
            mask_expanded = (
                attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            )
            sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
            # Clamp to avoid division by zero for all-padding rows
            eps = torch.finfo(hidden_states.dtype).eps
            sum_mask = torch.clamp(mask_expanded.sum(1), min=eps)
            return sum_embeddings / sum_mask
        else:
            return hidden_states.mean(dim=1)
# -----------------------------------------------------------------------------
# Pre-trained Model Base
# -----------------------------------------------------------------------------


class HELMBertPreTrainedModel(PreTrainedModel):
    """Base class for HELM-BERT models."""

    config_class = HELMBertConfig
    base_model_prefix = "helmbert"

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights with BERT-style initialization."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)


# -----------------------------------------------------------------------------
# Model Classes
# -----------------------------------------------------------------------------


class HELMBertModel(HELMBertPreTrainedModel):
    """HELM-BERT base model.

    This model outputs the last hidden states and optionally pooled output.

    Example:
        >>> from helmbert import HELMBertModel, HELMBertTokenizer
        >>> tokenizer = HELMBertTokenizer()
        >>> model = HELMBertModel.from_pretrained("./checkpoints/helmbert-base")
        >>> inputs = tokenizer("PEPTIDE1{A.C.D.E}$$$$", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooler_output = outputs.pooler_output
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = HELMBertEmbeddings(config)
        self.encoder = HELMBertEncoder(config)
        self.pooler = HELMBertPooler(config)

        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Forward pass.

        Args:
            input_ids: Token IDs [batch, seq]
            attention_mask: Attention mask [batch, seq]
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            return_dict: Whether to return a ModelOutput

        Returns:
            BaseModelOutputWithPooling or tuple
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Embeddings
        embeddings, position_embeddings = self.embeddings(input_ids, attention_mask)

        # Encoder
        encoder_outputs = self.encoder(
            embeddings,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_emd=False,
        )
        last_hidden_state = encoder_outputs[0]
        hidden_states = encoder_outputs[2]
        attentions = encoder_outputs[3]

        # Pooling
        pooler_output = self.pooler(last_hidden_state, attention_mask)

        if not return_dict:
            return (last_hidden_state, pooler_output, hidden_states, attentions)

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            hidden_states=hidden_states,
            attentions=attentions,
        )
class HELMBertLMHead(nn.Module):
    """MLM head with weight tying (HuggingFace standard)."""

    def __init__(self, config: HELMBertConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.activation = nn.GELU()
        # Decoder with weight tying (weight tied to the input embedding,
        # bias kept as a separate parameter)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            hidden_states: [batch, seq, hidden]

        Returns:
            Logits [batch, seq, vocab]
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        logits = self.decoder(hidden_states)
        return logits


class HELMBertForMaskedLM(HELMBertPreTrainedModel):
    """HELM-BERT for Masked Language Modeling with Enhanced Mask Decoder (EMD).

    Example:
        >>> from helmbert import HELMBertForMaskedLM, HELMBertTokenizer
        >>> tokenizer = HELMBertTokenizer()
        >>> model = HELMBertForMaskedLM.from_pretrained("./checkpoints/helmbert-base")
        >>> inputs = tokenizer("PEPTIDE1{A.¶.D.E}$$$$", return_tensors="pt")  # ¶ is the mask token
        >>> outputs = model(**inputs)
        >>> predictions = outputs.logits.argmax(dim=-1)
    """

    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config: HELMBertConfig):
        super().__init__(config)
        self.helmbert = HELMBertModel(config)
        self.lm_head = HELMBertLMHead(config)
        self.post_init()

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.lm_head.decoder = new_embeddings

    def get_input_embeddings(self) -> nn.Embedding:
        return self.helmbert.embeddings.word_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        use_emd: bool = True,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Forward pass.

        Args:
            input_ids: Token IDs [batch, seq]
            attention_mask: Attention mask [batch, seq]
            labels: Labels for MLM (-100 for non-masked tokens)
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            return_dict: Whether to return a ModelOutput
            use_emd: Whether to use the Enhanced Mask Decoder

        Returns:
            MaskedLMOutput or tuple
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Embeddings
        embeddings, position_embeddings = self.helmbert.embeddings(
            input_ids, attention_mask
        )

        # Encoder with optional EMD
        encoder_outputs = self.helmbert.encoder(
            embeddings,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_emd=use_emd,
        )

        # Use the EMD output if available, otherwise the last hidden state
        if use_emd and encoder_outputs[1] is not None:
            sequence_output = encoder_outputs[1]
        else:
            sequence_output = encoder_outputs[0]

        hidden_states = encoder_outputs[2]
        attentions = encoder_outputs[3]

        # MLM head (weight tying handled by HuggingFace)
        prediction_scores = self.lm_head(sequence_output)

        # Calculate loss if labels are provided
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores, hidden_states, attentions)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=hidden_states,
            attentions=attentions,
        )
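# Illustrative sketch (hypothetical helper): preparing an MLM batch that
# matches the labelling convention above, where positions set to -100 are
# ignored by the loss. `mask_token_id` is whatever ID the tokenizer assigns to
# its mask symbol; production collators typically also exclude special/pad
# tokens and apply the 80/10/10 corruption split.
def _demo_build_mlm_batch(
    input_ids: torch.Tensor, mask_token_id: int, mlm_prob: float = 0.15
) -> Tuple[torch.Tensor, torch.Tensor]:
    labels = input_ids.clone()
    masked = torch.bernoulli(
        torch.full(labels.shape, mlm_prob, dtype=torch.float)
    ).bool()
    labels[~masked] = -100             # only masked positions are scored
    corrupted = input_ids.clone()
    corrupted[masked] = mask_token_id  # the model must reconstruct these
    return corrupted, labels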
class MLPHead(nn.Module):
    """MLP head with skip connections for classification/regression.

    Architecture:
        input -> [Linear -> GELU -> LayerNorm -> Dropout (+ skip)] x N
              -> Linear -> output
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dims: list,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()

        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            self.layers.append(nn.Linear(prev_dim, hidden_dim))
            self.norms.append(nn.LayerNorm(hidden_dim))
            self.dropouts.append(nn.Dropout(dropout))
            prev_dim = hidden_dim

        self.output_layer = nn.Linear(prev_dim, output_dim)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer, norm, dropout in zip(self.layers, self.norms, self.dropouts):
            identity = x
            x = layer(x)
            if x.shape == identity.shape:
                x = x + identity  # Skip connection (only when dims match)
            x = self.activation(x)
            x = norm(x)
            x = dropout(x)
        return self.output_layer(x)
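# Illustrative sketch (hypothetical helper): the skip connection in MLPHead
# only fires when consecutive layer widths match, so a first layer whose
# input_dim differs from hidden_dims[0] runs without a residual.
def _demo_mlp_head() -> None:
    head = MLPHead(input_dim=16, output_dim=1, hidden_dims=[16, 16], dropout=0.0)
    x = torch.randn(4, 16)
    assert head(x).shape == (4, 1)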
class HELMBertForSequenceClassification(HELMBertPreTrainedModel):
    """HELM-BERT for sequence classification/regression.

    Example:
        >>> from helmbert import HELMBertForSequenceClassification, HELMBertConfig
        >>> # Simple linear head (default)
        >>> config = HELMBertConfig(num_labels=1)
        >>> model = HELMBertForSequenceClassification(config)
        >>>
        >>> # MLP head with 2 layers (for permeability prediction)
        >>> config = HELMBertConfig(num_labels=1, classifier_num_layers=2)
        >>> model = HELMBertForSequenceClassification(config)
    """

    def __init__(self, config: HELMBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.helmbert = HELMBertModel(config)

        # Use an MLP head if classifier_num_layers > 0, otherwise a linear head
        if config.classifier_num_layers > 0:
            hidden_dims = [config.hidden_size] * config.classifier_num_layers
            self.classifier = MLPHead(
                input_dim=config.hidden_size,
                output_dim=config.num_labels,
                hidden_dims=hidden_dims,
                dropout=config.classifier_dropout,
            )
        else:
            self.dropout = nn.Dropout(config.classifier_dropout)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Forward pass.

        Args:
            input_ids: Token IDs [batch, seq]
            attention_mask: Attention mask [batch, seq]
            labels: Labels for classification/regression
            output_attentions: Whether to return attention weights
            output_hidden_states: Whether to return all hidden states
            return_dict: Whether to return a ModelOutput

        Returns:
            SequenceClassifierOutput or tuple
        """
        outputs = self.helmbert(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        pooled_output = outputs.pooler_output

        # The MLP head has internal dropout; the linear head needs its own
        if hasattr(self, "dropout"):
            pooled_output = self.dropout(pooled_output)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
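# Hedged end-to-end sketch: a CPU smoke test with a toy configuration. It
# assumes HELMBertConfig follows the usual PretrainedConfig convention of
# accepting these fields as keyword arguments; the field names mirror the
# attributes read throughout this module, but check configuration_helmbert
# for the authoritative signature and defaults. Because this module uses a
# relative import, run it as `python -m <package>.modeling_helmbert`.
if __name__ == "__main__":
    config = HELMBertConfig(
        vocab_size=32,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=32,
        max_position_embeddings=64,
        max_relative_positions=32,
        position_buckets=-1,
        pos_att_type="c2p|p2c",
        share_att_key=True,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        ngie_kernel_size=3,
        ngie_dropout=0.1,
        pad_token_id=0,
        num_labels=1,
        classifier_num_layers=0,
        classifier_dropout=0.1,
    )
    model = HELMBertForSequenceClassification(config)
    input_ids = torch.randint(1, 32, (2, 10))
    out = model(input_ids, labels=torch.randn(2, 1))
    print(out.logits.shape, out.loss.item())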