Flansma committed on
Commit 45723db · verified · 1 Parent(s): 9ace66c

Upload folder using huggingface_hub

config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "architectures": [
+     "HELMBertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 1,
+   "dtype": "float32",
+   "eos_token_id": 2,
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 256,
+   "intermediate_size": 3072,
+   "mask_token_id": 4,
+   "max_position_embeddings": 512,
+   "max_relative_positions": 512,
+   "model_type": "helmbert",
+   "ngie_dropout": 0.1,
+   "ngie_kernel_size": 3,
+   "num_attention_heads": 4,
+   "num_hidden_layers": 2,
+   "pad_token_id": 0,
+   "pos_att_type": "c2p|p2c",
+   "position_buckets": 256,
+   "sep_token_id": 2,
+   "share_att_key": false,
+   "transformers_version": "4.57.3",
+   "vocab_size": 78,
+   "auto_map": {
+     "AutoConfig": "configuration_helmbert.HELMBertConfig",
+     "AutoModel": "modeling_helmbert.HELMBertModel",
+     "AutoModelForMaskedLM": "modeling_helmbert.HELMBertForMaskedLM",
+     "AutoModelForSequenceClassification": "modeling_helmbert.HELMBertForSequenceClassification",
+     "AutoTokenizer": "tokenization_helmbert.HELMBertTokenizer"
+   }
+ }
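
Because config.json carries an auto_map, the custom classes in this repo resolve automatically when loading through the Auto API with trust_remote_code=True. A minimal sketch; the local path is a placeholder for wherever this repo is downloaded:

    from transformers import AutoConfig

    # Placeholder path: any directory containing the files from this commit works.
    config = AutoConfig.from_pretrained("./helmbert-base", trust_remote_code=True)
    print(config.model_type)   # "helmbert"
    print(config.hidden_size)  # 256 for this checkpoint
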
configuration_helmbert.py ADDED
@@ -0,0 +1,104 @@
+ """HELM-BERT configuration."""
+
+ from transformers import PretrainedConfig
+
+
+ class HELMBertConfig(PretrainedConfig):
+     """Configuration class for HELM-BERT model.
+
+     This configuration class stores all the parameters needed to instantiate a HELM-BERT model.
+     It inherits from PretrainedConfig and can be used with HuggingFace's from_pretrained and
+     save_pretrained methods.
+
+     Args:
+         vocab_size: Size of the vocabulary (default: 78 for HELM character vocabulary)
+         hidden_size: Dimensionality of the encoder layers (default: 768)
+         num_hidden_layers: Number of transformer layers (default: 6)
+         num_attention_heads: Number of attention heads (default: 12)
+         intermediate_size: Dimensionality of the feed-forward layer (default: 3072)
+         hidden_dropout_prob: Dropout probability for hidden layers (default: 0.1)
+         attention_probs_dropout_prob: Dropout probability for attention (default: 0.1)
+         max_position_embeddings: Maximum sequence length (default: 512)
+         max_relative_positions: Maximum relative position distance (default: 512)
+         position_buckets: Number of position buckets for log-bucketing (default: 256)
+         pos_att_type: Position attention types, pipe-separated (default: "c2p|p2c")
+         share_att_key: Whether to share attention key projections (default: False)
+         ngie_kernel_size: Kernel size for nGiE convolution (default: 3)
+         ngie_dropout: Dropout for nGiE layer (default: 0.1)
+         pad_token_id: ID of padding token (default: 0)
+         bos_token_id: ID of beginning-of-sequence token (default: 1)
+         eos_token_id: ID of end-of-sequence token (default: 2)
+         sep_token_id: ID of separator token (default: 2)
+         mask_token_id: ID of mask token (default: 4)
+
+     Example:
+         >>> from helmbert import HELMBertConfig, HELMBertModel
+         >>> config = HELMBertConfig(hidden_size=768, num_hidden_layers=6)
+         >>> model = HELMBertModel(config)
+     """
+
+     model_type = "helmbert"
+
+     def __init__(
+         self,
+         vocab_size: int = 78,
+         hidden_size: int = 768,
+         num_hidden_layers: int = 6,
+         num_attention_heads: int = 12,
+         intermediate_size: int = 3072,
+         hidden_dropout_prob: float = 0.1,
+         attention_probs_dropout_prob: float = 0.1,
+         max_position_embeddings: int = 512,
+         # Disentangled attention parameters
+         max_relative_positions: int = 512,
+         position_buckets: int = 256,
+         pos_att_type: str = "c2p|p2c",
+         share_att_key: bool = False,
+         # nGiE parameters
+         ngie_kernel_size: int = 3,
+         ngie_dropout: float = 0.1,
+         # Special token IDs
+         pad_token_id: int = 0,
+         bos_token_id: int = 1,
+         eos_token_id: int = 2,
+         sep_token_id: int = 2,
+         mask_token_id: int = 4,
+         # Classification/regression
+         num_labels: int = 2,
+         problem_type: str = None,
+         **kwargs,
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             **kwargs,
+         )
+
+         # Core transformer parameters
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+
+         # Disentangled attention parameters
+         self.max_relative_positions = max_relative_positions
+         self.position_buckets = position_buckets
+         self.pos_att_type = pos_att_type
+         self.share_att_key = share_att_key
+
+         # nGiE parameters
+         self.ngie_kernel_size = ngie_kernel_size
+         self.ngie_dropout = ngie_dropout
+
+         # Special token IDs
+         self.sep_token_id = sep_token_id
+         self.mask_token_id = mask_token_id
+
+         # Classification/regression
+         self.num_labels = num_labels
+         self.problem_type = problem_type
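
Note that the class defaults (hidden_size=768, num_hidden_layers=6, num_attention_heads=12) are larger than the values shipped in config.json above (256/2/4), so reproducing this checkpoint's shape requires explicit overrides. A minimal sketch, assuming the module is importable from the repo root and the save directory is a placeholder:

    from configuration_helmbert import HELMBertConfig

    # Overrides matching the config.json in this commit.
    config = HELMBertConfig(hidden_size=256, num_hidden_layers=2, num_attention_heads=4)
    config.save_pretrained("./helmbert-tiny")  # writes config.json to a placeholder directory
    reloaded = HELMBertConfig.from_pretrained("./helmbert-tiny")
    assert reloaded.hidden_size == 256 and reloaded.vocab_size == 78
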
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b2a6686b46de073e9a373938637fe0de1817cfb1fd2961f7daf1e4dafb7ced9
+ size 18489472
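
This is a Git LFS pointer; the actual weights resolve when the repo is fetched with LFS. As a rough sanity check, 18489472 bytes at 4 bytes per float32 parameter is about 4.6M parameters. A minimal sketch of inspecting the resolved file directly (assumes the LFS object has been pulled; the safetensors header makes the file slightly larger than parameters × 4 bytes):

    from safetensors.torch import load_file

    state_dict = load_file("model.safetensors")  # path assumes the real file, not the pointer
    n_params = sum(t.numel() for t in state_dict.values())
    print(f"{n_params:,} float32 parameters")    # roughly 4.6M for an ~18.5 MB file
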
modeling_helmbert.py ADDED
@@ -0,0 +1,988 @@
+ """HELM-BERT model implementation.
+
+ This module implements the HELM-BERT model with:
+ - Disentangled attention (DeBERTa-style)
+ - Enhanced Mask Decoder (EMD) for MLM
+ - n-gram Induced Encoding (nGiE) layer
+ """
+
+ import math
+ from dataclasses import dataclass
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from packaging import version
+ from torch import _softmax_backward_data
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPooling,
+     MaskedLMOutput,
+     SequenceClassifierOutput,
+ )
+
+ from .configuration_helmbert import HELMBertConfig
+
+
+ # -----------------------------------------------------------------------------
+ # Utility Functions
+ # -----------------------------------------------------------------------------
+
+
+ def masked_layer_norm(
+     layer_norm: nn.LayerNorm, x: torch.Tensor, mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+     """Apply LayerNorm with masking to avoid updates on padding tokens.
+
+     Args:
+         layer_norm: LayerNorm module
+         x: Input tensor (batch_size, seq_len, hidden_size)
+         mask: Mask tensor where 0 = padding (ignored), 1 = valid token
+
+     Returns:
+         Normalized tensor with padding positions zeroed out
+     """
+     output = layer_norm(x).to(x)
+     if mask is None:
+         return output
+     if mask.dim() != x.dim():
+         if mask.dim() == 4:
+             mask = mask.squeeze(1).squeeze(1)
+         mask = mask.unsqueeze(2)
+     mask = mask.to(output.dtype)
+     return output * mask
+
+
+ class XSoftmax(torch.autograd.Function):
+     """Masked Softmax optimized for memory efficiency."""
+
+     @staticmethod
+     def forward(ctx, input: torch.Tensor, mask: Optional[torch.Tensor], dim: int) -> torch.Tensor:
+         ctx.dim = dim
+         if mask is not None:
+             rmask = ~(mask.bool())
+             if rmask.dim() == 2:
+                 rmask = rmask.unsqueeze(1).unsqueeze(2)
+             elif rmask.dim() == 3:
+                 rmask = rmask.unsqueeze(2)
+             output = input.masked_fill(rmask, float("-inf"))
+         else:
+             output = input
+         output = torch.softmax(output, ctx.dim)
+         if mask is not None:
+             output.masked_fill_(rmask, 0)
+         ctx.save_for_backward(output)
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
+         (output,) = ctx.saved_tensors
+         if version.Version(torch.__version__) >= version.Version("1.11.0"):
+             input_grad = _softmax_backward_data(grad_output, output, ctx.dim, output.dtype)
+         else:
+             input_grad = _softmax_backward_data(grad_output, output, ctx.dim, output)
+         return input_grad, None, None
+
+
+ def build_relative_position(
+     query_size: int,
+     key_size: int,
+     bucket_size: int = -1,
+     max_position: int = 512,
+     device: Optional[torch.device] = None,
+ ) -> torch.Tensor:
+     """Build relative position matrix with optional log-bucketing."""
+     q_ids = torch.arange(query_size, dtype=torch.long, device=device)
+     k_ids = torch.arange(key_size, dtype=torch.long, device=device)
+     rel_pos = q_ids.unsqueeze(1) - k_ids.unsqueeze(0)
+
+     if bucket_size > 0:
+         rel_buckets = 0
+         num_buckets = bucket_size
+         rel_buckets += (rel_pos > 0).long() * (num_buckets // 2)
+         rel_pos = torch.abs(rel_pos)
+
+         max_exact = num_buckets // 4
+         is_small = rel_pos < max_exact
+
+         rel_pos_if_large = max_exact + (
+             torch.log(rel_pos.float() / max_exact)
+             / math.log(max_position / max_exact)
+             * (num_buckets // 4 - 1)
+         ).long()
+         rel_pos_if_large = torch.min(
+             rel_pos_if_large, torch.full_like(rel_pos_if_large, num_buckets // 2 - 1)
+         )
+
+         rel_buckets += torch.where(is_small, rel_pos, rel_pos_if_large)
+         return rel_buckets
+     else:
+         rel_pos = torch.clamp(rel_pos, -max_position, max_position)
+         return rel_pos + max_position
+
+
+ # -----------------------------------------------------------------------------
+ # Attention Modules
+ # -----------------------------------------------------------------------------
+
+
+ class DisentangledSelfAttention(nn.Module):
+     """Disentangled self-attention with content and position separation.
+
+     Implements content-to-content, content-to-position, and position-to-content
+     attention as described in DeBERTa.
+     """
+
+     def __init__(self, config: HELMBertConfig):
+         super().__init__()
+
+         if config.hidden_size % config.num_attention_heads != 0:
+             raise ValueError(
+                 f"hidden_size ({config.hidden_size}) must be divisible by "
+                 f"num_attention_heads ({config.num_attention_heads})"
+             )
+
+         self.num_heads = config.num_attention_heads
+         self.head_size = config.hidden_size // config.num_attention_heads
+         self.all_head_size = self.num_heads * self.head_size
+
+         # Content projections
+         self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+         self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+         self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+
+         # Position attention configuration
+         self.pos_att_type = [x.strip() for x in config.pos_att_type.lower().split("|")]
+         self.max_relative_positions = config.max_relative_positions
+         self.position_buckets = config.position_buckets
+         self.share_att_key = config.share_att_key
+
+         # Position embedding size
+         self.pos_ebd_size = config.max_relative_positions
+         if config.position_buckets > 0:
+             self.pos_ebd_size = config.position_buckets
+
+         # Position embeddings
+         self.rel_embeddings = nn.Embedding(self.pos_ebd_size * 2, config.hidden_size)
+
+         # Position projections
+         if not self.share_att_key:
+             if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+                 self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+             if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+                 self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+
+         # Dropout
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+         self.pos_dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+         """Reshape tensor for attention computation."""
+         new_shape = x.size()[:-1] + (self.num_heads, self.head_size)
+         x = x.view(*new_shape)
+         return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+         query_states: Optional[torch.Tensor] = None,
+         relative_pos: Optional[torch.Tensor] = None,
+         rel_embeddings: Optional[torch.Tensor] = None,
+     ) -> Dict[str, Any]:
+         """Forward pass of disentangled attention."""
+         if query_states is None:
+             query_states = hidden_states
+
+         # Compute Q, K, V
+         query_layer = self.transpose_for_scores(self.query_proj(query_states)).float()
+         key_layer = self.transpose_for_scores(self.key_proj(hidden_states)).float()
+         value_layer = self.transpose_for_scores(self.value_proj(hidden_states))
+
+         # Calculate scale factor
+         scale_factor = 1
+         if "c2p" in self.pos_att_type:
+             scale_factor += 1
+         if "p2c" in self.pos_att_type:
+             scale_factor += 1
+         if "p2p" in self.pos_att_type:
+             scale_factor += 1
+
+         scale = 1.0 / math.sqrt(self.head_size * scale_factor)
+
+         # Content-to-content attention (c2c)
+         c2c_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) * scale)
+         attention_scores = c2c_scores
+
+         # Add relative position bias if enabled
+         if len(self.pos_att_type) > 0 and self.pos_att_type[0]:
+             rel_att = self._disentangled_attention_bias(
+                 query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
+             )
+             if rel_att is not None:
+                 attention_scores = attention_scores + rel_att
+
+         # Normalize scores for numerical stability
+         attention_scores = attention_scores - attention_scores.max(dim=-1, keepdim=True)[0].detach()
+         attention_scores = attention_scores.to(hidden_states.dtype)
+
+         # Reshape for XSoftmax
+         attention_scores = attention_scores.view(
+             -1, self.num_heads, attention_scores.size(-2), attention_scores.size(-1)
+         )
+
+         # Apply XSoftmax
+         attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
+         attention_probs = self.dropout(attention_probs)
+
+         # Apply attention to values
+         attention_probs_flat = attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1))
+         context_layer = torch.bmm(attention_probs_flat, value_layer)
+
+         # Reshape output
+         context_layer = context_layer.view(-1, self.num_heads, context_layer.size(-2), context_layer.size(-1))
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(*new_shape)
+
+         result = {"hidden_states": context_layer, "attention_probs": attention_probs}
+         return result
+
+     def _disentangled_attention_bias(
+         self,
+         query_layer: torch.Tensor,
+         key_layer: torch.Tensor,
+         relative_pos: Optional[torch.Tensor],
+         rel_embeddings: Optional[torch.Tensor],
+         scale_factor: int,
+     ) -> Optional[torch.Tensor]:
+         """Compute disentangled attention bias."""
+         if relative_pos is None:
+             q_size = query_layer.size(-2)
+             k_size = key_layer.size(-2)
+             relative_pos = build_relative_position(
+                 q_size,
+                 k_size,
+                 bucket_size=self.position_buckets,
+                 max_position=self.max_relative_positions,
+                 device=query_layer.device,
+             )
+
+         if relative_pos.dim() == 2:
+             relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
+         elif relative_pos.dim() == 3:
+             relative_pos = relative_pos.unsqueeze(1)
+
+         batch_size = query_layer.size(0) // self.num_heads
+
+         # Get position embeddings
+         if rel_embeddings is None:
+             rel_embeddings = self.rel_embeddings.weight
+
+         att_span = self.pos_ebd_size
+         rel_embeddings = rel_embeddings[
+             self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :
+         ].unsqueeze(0)
+         rel_embeddings = self.pos_dropout(rel_embeddings)
+
+         score = torch.zeros_like(query_layer[:, :, :1]).expand(-1, -1, key_layer.size(-2))
+
+         # Prepare position indices
+         c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
+         c2p_pos = c2p_pos.squeeze(0).expand(query_layer.size(0), query_layer.size(1), relative_pos.size(-1))
+
+         # Content-to-position (c2p)
+         if "c2p" in self.pos_att_type:
+             pos_key_layer = (
+                 self.pos_key_proj(rel_embeddings) if not self.share_att_key else self.key_proj(rel_embeddings)
+             )
+             pos_key_layer = self.transpose_for_scores(pos_key_layer).repeat(batch_size, 1, 1)
+
+             c2p_scale = 1.0 / math.sqrt(self.head_size * scale_factor)
+             c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2) * c2p_scale)
+             c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos)
+             score = score + c2p_att
+
+         # Position-to-content (p2c)
+         if "p2c" in self.pos_att_type:
+             pos_query_layer = (
+                 self.pos_query_proj(rel_embeddings) if not self.share_att_key else self.query_proj(rel_embeddings)
+             )
+             pos_query_layer = self.transpose_for_scores(pos_query_layer).repeat(batch_size, 1, 1)
+
+             p2c_scale = 1.0 / math.sqrt(self.head_size * scale_factor)
+             p2c_att = torch.bmm(pos_query_layer * p2c_scale, key_layer.transpose(-1, -2))
+             p2c_att = torch.gather(p2c_att, dim=-2, index=c2p_pos)
+             score = score + p2c_att
+
+         return score
+
+
+ # -----------------------------------------------------------------------------
+ # Transformer Components
+ # -----------------------------------------------------------------------------
+
+
+ class HELMBertEmbeddings(nn.Module):
+     """Token and position embeddings for HELM-BERT."""
+
+     def __init__(self, config: HELMBertConfig):
+         super().__init__()
+         self.word_embeddings = nn.Embedding(
+             config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+         )
+         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+         self.layer_norm = nn.LayerNorm(config.hidden_size)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Forward pass.
+
+         Returns:
+             Tuple of (token_embeddings, position_embeddings)
+         """
+         batch_size, seq_len = input_ids.shape
+
+         # Token embeddings
+         embeddings = self.word_embeddings(input_ids)
+
+         # Position embeddings
+         position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
+         position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+         position_embeds = self.position_embeddings(position_ids)
+
+         # Layer norm and dropout
+         embeddings = masked_layer_norm(self.layer_norm, embeddings, attention_mask)
+         embeddings = self.dropout(embeddings)
+
+         return embeddings, position_embeds
+
+
+ class NgieLayer(nn.Module):
+     """n-gram Induced Input Encoding (nGiE) layer.
+
+     Captures local n-gram patterns using 1D convolution.
+     """
+
+     def __init__(self, config: HELMBertConfig):
+         super().__init__()
+
+         self.conv = nn.Conv1d(
+             in_channels=config.hidden_size,
+             out_channels=config.hidden_size,
+             kernel_size=config.ngie_kernel_size,
+             padding=(config.ngie_kernel_size - 1) // 2,
+             groups=1,
+         )
+         self.activation = nn.Tanh()
+         self.layer_norm = nn.LayerNorm(config.hidden_size)
+         self.dropout = nn.Dropout(config.ngie_dropout)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         residual_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+     ) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             hidden_states: Input to convolution (batch, seq, hidden)
+             residual_states: States for residual connection (batch, seq, hidden)
+             attention_mask: Mask where 1 = valid, 0 = padding
+
+         Returns:
+             Output with n-gram information incorporated
+         """
+         # Apply 1D convolution
+         out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
+
+         # Create reverse mask for padding
+         if version.Version(torch.__version__) >= version.Version("1.2.0a"):
+             rmask = (1 - attention_mask).bool()
+         else:
+             rmask = (1 - attention_mask).byte()
+
+         # Zero out padding positions
+         out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
+
+         # Apply activation and dropout
+         out = self.activation(self.dropout(out))
+
+         # Residual connection with LayerNorm
+         output_states = masked_layer_norm(self.layer_norm, residual_states + out, attention_mask)
+
+         return output_states
+
+
+ class TransformerBlock(nn.Module):
+     """Transformer block with disentangled attention and GELU FFN."""
+
+     def __init__(self, config: HELMBertConfig):
+         super().__init__()
+
+         self.self_attn = DisentangledSelfAttention(config)
+         self.attn_output_dense = nn.Linear(config.hidden_size, config.hidden_size)
+
+         # FFN with GELU
+         self.linear1 = nn.Sequential(
+             nn.Linear(config.hidden_size, config.intermediate_size), nn.GELU()
+         )
+         self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+         # Normalization and dropout
+         self.norm1 = nn.LayerNorm(config.hidden_size)
+         self.norm2 = nn.LayerNorm(config.hidden_size)
+         self.dropout1 = nn.Dropout(config.hidden_dropout_prob)
+         self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(
+         self,
+         src: torch.Tensor,
+         src_key_padding_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+         query_states: Optional[torch.Tensor] = None,
+         relative_pos: Optional[torch.Tensor] = None,
+         rel_embeddings: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         """Forward pass.
+
+         Args:
+             src: Input embeddings [seq_len, batch, hidden]
+             src_key_padding_mask: Padding mask [batch, seq_len]
+             output_attentions: Whether to return attention weights
+             query_states: Optional query for EMD
+             relative_pos: Relative position indices
+             rel_embeddings: Relative position embeddings
+
+         Returns:
+             Tuple of (output, optional attention weights)
+         """
+         # Transpose for attention [seq, batch, hidden] -> [batch, seq, hidden]
+         src_transposed = src.transpose(0, 1)
+
+         # Convert padding mask to attention mask (1=valid, 0=padding)
+         attention_mask = None
+         if src_key_padding_mask is not None:
+             attention_mask = (~src_key_padding_mask).float()
+
+         query_states_transposed = None
+         if query_states is not None:
+             query_states_transposed = query_states.transpose(0, 1)
+
+         # Self-attention
+         attn_result = self.self_attn(
+             src_transposed,
+             attention_mask,
+             output_attentions=output_attentions,
+             query_states=query_states_transposed,
+             relative_pos=relative_pos,
+             rel_embeddings=rel_embeddings,
+         )
+         attn_output = attn_result["hidden_states"].transpose(0, 1)
+         attn_weights = attn_result.get("attention_probs") if output_attentions else None
+
+         # Dense projection
+         attn_output = self.attn_output_dense(attn_output)
+
+         # Residual connection
+         residual_input = query_states if query_states is not None else src
+         src = residual_input + self.dropout1(attn_output)
+
+         # LayerNorm
+         src = src.transpose(0, 1)
+         src = masked_layer_norm(self.norm1, src)
+         src = src.transpose(0, 1)
+
+         # FFN
+         ff_output = self.linear1(src)
+         ff_output = self.linear2(ff_output)
+         ff_output = self.dropout2(ff_output)
+         src = src + ff_output
+
+         # LayerNorm
+         src = src.transpose(0, 1)
+         src = masked_layer_norm(self.norm2, src)
+         src = src.transpose(0, 1)
+
+         return src, attn_weights
+
+
+ class HELMBertEncoder(nn.Module):
+     """Stack of transformer blocks with nGiE layer."""
+
+     def __init__(self, config: HELMBertConfig):
+         super().__init__()
+         self.config = config
+
+         # nGiE layer (applied after first transformer block)
+         self.ngie_layer = NgieLayer(config)
+
+         # Transformer blocks
+         self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])
+
+     def get_rel_embedding(self) -> Optional[torch.Tensor]:
+         """Get relative position embeddings from first layer."""
+         if len(self.layers) > 0:
+             first_layer = self.layers[0]
+             if hasattr(first_layer, "self_attn") and hasattr(first_layer.self_attn, "rel_embeddings"):
+                 return first_layer.self_attn.rel_embeddings.weight
+         return None
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_embeddings: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         use_emd: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple], Optional[Tuple]]:
+         """Forward pass.
+
+         Args:
+             hidden_states: Input embeddings [batch, seq, hidden]
+             attention_mask: Attention mask [batch, seq]
+             position_embeddings: Position embeddings for EMD
+             output_attentions: Whether to return attention weights
+             output_hidden_states: Whether to return all hidden states
+             use_emd: Whether to use Enhanced Mask Decoder
+
+         Returns:
+             Tuple of (last_hidden_state, emd_output, all_hidden_states, all_attentions)
+         """
+         all_hidden_states = () if output_hidden_states else None
+         all_attentions = () if output_attentions else None
+
+         # Store for nGiE
+         ngie_input_states = hidden_states.clone()
+
+         # [batch, seq, hidden] -> [seq, batch, hidden]
+         hidden_states = hidden_states.transpose(0, 1)
+
+         # Key padding mask (True = padding)
+         key_padding_mask = None
+         if attention_mask is not None:
+             key_padding_mask = ~attention_mask.bool()
+
+         # Store layer[-2] for EMD
+         layer_minus_2 = None
+         num_layers = len(self.layers)
+
+         for layer_idx, layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)
+
+             hidden_states, attn_weights = layer(
+                 hidden_states,
+                 src_key_padding_mask=key_padding_mask,
+                 output_attentions=output_attentions,
+             )
+
+             if output_attentions and attn_weights is not None:
+                 all_attentions = all_attentions + (attn_weights,)
+
+             # Apply nGiE after first layer
+             if layer_idx == 0:
+                 hidden_states_batch = hidden_states.transpose(0, 1)
+                 hidden_states_batch = self.ngie_layer(ngie_input_states, hidden_states_batch, attention_mask)
+                 hidden_states = hidden_states_batch.transpose(0, 1)
+
+             # Store layer[-2] for EMD
+             if use_emd and layer_idx == num_layers - 2:
+                 layer_minus_2 = hidden_states.clone()
+
+         # Convert back to [batch, seq, hidden]
+         hidden_states = hidden_states.transpose(0, 1)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         # Enhanced Mask Decoder (EMD) for MLM
+         emd_output = None
+         if use_emd and layer_minus_2 is not None and position_embeddings is not None:
+             emd_keys_values = layer_minus_2
+             emd_query = layer_minus_2.transpose(0, 1)
+             emd_query = position_embeddings + emd_query
+             emd_query = emd_query.transpose(0, 1)
+
+             rel_embeddings = self.get_rel_embedding()
+             last_layer = self.layers[-1]
+
+             for _ in range(2):
+                 emd_query, _ = last_layer(
+                     emd_keys_values,
+                     src_key_padding_mask=key_padding_mask,
+                     query_states=emd_query,
+                     relative_pos=None,
+                     rel_embeddings=rel_embeddings,
+                 )
+
+             emd_output = emd_query.transpose(0, 1)
+
+         return hidden_states, emd_output, all_hidden_states, all_attentions
+
+
+ class HELMBertPooler(nn.Module):
+     """Mean pooling over sequence."""
+
+     def __init__(self, config: HELMBertConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+
+     def forward(
+         self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+     ) -> torch.Tensor:
+         """Apply mean pooling.
+
+         Args:
+             hidden_states: [batch, seq, hidden]
+             attention_mask: [batch, seq]
+
+         Returns:
+             Pooled output [batch, hidden]
+         """
+         if attention_mask is not None:
+             mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
+             sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
+             eps = torch.finfo(hidden_states.dtype).eps
+             sum_mask = torch.clamp(mask_expanded.sum(1), min=eps)
+             return sum_embeddings / sum_mask
+         else:
+             return hidden_states.mean(dim=1)
+
+
+ # -----------------------------------------------------------------------------
+ # Pre-trained Model Base
+ # -----------------------------------------------------------------------------
+
+
+ class HELMBertPreTrainedModel(PreTrainedModel):
+     """Base class for HELM-BERT models."""
+
+     config_class = HELMBertConfig
+     base_model_prefix = "helmbert"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module: nn.Module) -> None:
+         """Initialize weights with BERT-style initialization."""
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.ones_(module.weight)
+             nn.init.zeros_(module.bias)
+
+
+ # -----------------------------------------------------------------------------
+ # Model Classes
+ # -----------------------------------------------------------------------------
+
+
+ class HELMBertModel(HELMBertPreTrainedModel):
+     """HELM-BERT base model.
+
+     This model outputs the last hidden states and optionally pooled output.
+
+     Example:
+         >>> from helmbert import HELMBertModel, HELMBertTokenizer
+         >>> tokenizer = HELMBertTokenizer()
+         >>> model = HELMBertModel.from_pretrained("./checkpoints/helmbert-base")
+         >>> inputs = tokenizer("PEPTIDE1{A.C.D.E}$$$$", return_tensors="pt")
+         >>> outputs = model(**inputs)
+         >>> last_hidden_state = outputs.last_hidden_state
+         >>> pooler_output = outputs.pooler_output
+     """
+
+     def __init__(self, config: HELMBertConfig):
+         super().__init__(config)
+         self.config = config
+
+         self.embeddings = HELMBertEmbeddings(config)
+         self.encoder = HELMBertEncoder(config)
+         self.pooler = HELMBertPooler(config)
+
+         self.post_init()
+
+     def get_input_embeddings(self) -> nn.Embedding:
+         return self.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value: nn.Embedding) -> None:
+         self.embeddings.word_embeddings = value
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         """Forward pass.
+
+         Args:
+             input_ids: Token IDs [batch, seq]
+             attention_mask: Attention mask [batch, seq]
+             output_attentions: Whether to return attention weights
+             output_hidden_states: Whether to return all hidden states
+             return_dict: Whether to return a ModelOutput
+
+         Returns:
+             BaseModelOutputWithPooling or tuple
+         """
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+
+         # Embeddings
+         embeddings, position_embeddings = self.embeddings(input_ids, attention_mask)
+
+         # Encoder
+         encoder_outputs = self.encoder(
+             embeddings,
+             attention_mask=attention_mask,
+             position_embeddings=position_embeddings,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             use_emd=False,
+         )
+
+         last_hidden_state = encoder_outputs[0]
+         hidden_states = encoder_outputs[2]
+         attentions = encoder_outputs[3]
+
+         # Pooling
+         pooler_output = self.pooler(last_hidden_state, attention_mask)
+
+         if not return_dict:
+             return (last_hidden_state, pooler_output, hidden_states, attentions)
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=last_hidden_state,
+             pooler_output=pooler_output,
+             hidden_states=hidden_states,
+             attentions=attentions,
+         )
+
+
+ class HELMBertLMHead(nn.Module):
+     """MLM head with weight tying."""
+
+     def __init__(self, config: HELMBertConfig):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.layer_norm = nn.LayerNorm(config.hidden_size)
+         self.activation = nn.GELU()
+         self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         """Forward pass.
+
+         Args:
+             hidden_states: [batch, seq, hidden]
+
+         Returns:
+             Logits [batch, seq, vocab]
+         """
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.activation(hidden_states)
+         hidden_states = self.layer_norm(hidden_states)
+         logits = self.decoder(hidden_states)
+         return logits
+
+
+ class HELMBertForMaskedLM(HELMBertPreTrainedModel):
+     """HELM-BERT for Masked Language Modeling with Enhanced Mask Decoder (EMD).
+
+     Example:
+         >>> from helmbert import HELMBertForMaskedLM, HELMBertTokenizer
+         >>> tokenizer = HELMBertTokenizer()
+         >>> model = HELMBertForMaskedLM.from_pretrained("./checkpoints/helmbert-base")
+         >>> inputs = tokenizer("PEPTIDE1{A.¶.D.E}$$$$", return_tensors="pt")  # ¶ is mask
+         >>> outputs = model(**inputs)
+         >>> predictions = outputs.logits.argmax(dim=-1)
+     """
+
+     _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
+     def __init__(self, config: HELMBertConfig):
+         super().__init__(config)
+         self.helmbert = HELMBertModel(config)
+         self.lm_head = HELMBertLMHead(config)
+
+         self.post_init()
+
+     def get_output_embeddings(self) -> nn.Linear:
+         return self.lm_head.decoder
+
+     def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+         self.lm_head.decoder = new_embeddings
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+         use_emd: bool = True,
+     ) -> Union[Tuple, MaskedLMOutput]:
+         """Forward pass.
+
+         Args:
+             input_ids: Token IDs [batch, seq]
+             attention_mask: Attention mask [batch, seq]
+             labels: Labels for MLM (-100 for non-masked tokens)
+             output_attentions: Whether to return attention weights
+             output_hidden_states: Whether to return all hidden states
+             return_dict: Whether to return a ModelOutput
+             use_emd: Whether to use Enhanced Mask Decoder
+
+         Returns:
+             MaskedLMOutput or tuple
+         """
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+
+         # Embeddings
+         embeddings, position_embeddings = self.helmbert.embeddings(input_ids, attention_mask)
+
+         # Encoder with optional EMD
+         encoder_outputs = self.helmbert.encoder(
+             embeddings,
+             attention_mask=attention_mask,
+             position_embeddings=position_embeddings,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             use_emd=use_emd,
+         )
+
+         # Use EMD output if available, otherwise use last hidden state
+         if use_emd and encoder_outputs[1] is not None:
+             sequence_output = encoder_outputs[1]
+         else:
+             sequence_output = encoder_outputs[0]
+
+         hidden_states = encoder_outputs[2]
+         attentions = encoder_outputs[3]
+
+         # MLM head
+         prediction_scores = self.lm_head(sequence_output)
+
+         # Calculate loss if labels provided
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+             loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+         if not return_dict:
+             output = (prediction_scores, hidden_states, attentions)
+             return ((loss,) + output) if loss is not None else output
+
+         return MaskedLMOutput(
+             loss=loss,
+             logits=prediction_scores,
+             hidden_states=hidden_states,
+             attentions=attentions,
+         )
+
+
+ class HELMBertForSequenceClassification(HELMBertPreTrainedModel):
+     """HELM-BERT for sequence classification/regression.
+
+     Example:
+         >>> from helmbert import HELMBertForSequenceClassification, HELMBertConfig
+         >>> config = HELMBertConfig(num_labels=1)  # Regression
+         >>> model = HELMBertForSequenceClassification(config)
+     """
+
+     def __init__(self, config: HELMBertConfig):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+
+         self.helmbert = HELMBertModel(config)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+     ) -> Union[Tuple, SequenceClassifierOutput]:
+         """Forward pass.
+
+         Args:
+             input_ids: Token IDs [batch, seq]
+             attention_mask: Attention mask [batch, seq]
+             labels: Labels for classification/regression
+             output_attentions: Whether to return attention weights
+             output_hidden_states: Whether to return all hidden states
+             return_dict: Whether to return a ModelOutput
+
+         Returns:
+             SequenceClassifierOutput or tuple
+         """
+         outputs = self.helmbert(
+             input_ids,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=True,
+         )
+
+         pooled_output = outputs.pooler_output
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = nn.MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = nn.CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = nn.BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
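
End to end, masked-residue prediction follows the HELMBertForMaskedLM docstring: mask a monomer with ¶, run the encoder with the Enhanced Mask Decoder (enabled by default for MLM), and read the argmax at the masked position. A minimal sketch using the Auto API so the relative import in modeling_helmbert.py resolves; the checkpoint path is a placeholder for a local download of this repo:

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # Placeholder path: a local download of this repo.
    tokenizer = AutoTokenizer.from_pretrained("./helmbert-base", trust_remote_code=True)
    model = AutoModelForMaskedLM.from_pretrained("./helmbert-base", trust_remote_code=True)
    model.eval()

    inputs = tokenizer("PEPTIDE1{A.¶.D.E}$$$$", return_tensors="pt")  # ¶ marks the masked residue
    with torch.no_grad():
        logits = model(**inputs).logits  # use_emd defaults to True

    mask_positions = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
    predicted_ids = logits[mask_positions].argmax(dim=-1)
    print(tokenizer.convert_ids_to_tokens(predicted_ids.tolist()))
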
special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "bos_token": "@",
+   "cls_token": "@",
+   "eos_token": "\n",
+   "mask_token": "¶",
+   "pad_token": " ",
+   "sep_token": "\n",
+   "unk_token": "§"
+ }
tokenization_helmbert.py ADDED
@@ -0,0 +1,285 @@
+ """HELM-BERT tokenizer."""
+
+ import json
+ import os
+ from typing import Dict, List, Optional, Tuple, Union
+
+ from transformers import PreTrainedTokenizer
+
+
+ # Default vocabulary for HELM notation
+ HELM_VOCAB = {
+     # Special tokens (0-4)
+     " ": 0,   # PAD
+     "@": 1,   # BOS/CLS
+     "\n": 2,  # EOS/SEP
+     "§": 3,   # UNK
+     "¶": 4,   # MASK
+
+     # Natural amino acids (5-25)
+     "A": 5, "R": 6, "N": 7, "D": 8, "C": 9,
+     "E": 10, "Q": 11, "G": 12, "H": 13, "I": 14,
+     "L": 15, "K": 16, "M": 17, "F": 18, "P": 19,
+     "S": 20, "T": 21, "W": 22, "Y": 23, "V": 24,
+     "X": 25,  # Unknown amino acid
+
+     # Structure symbols (26-37)
+     "[": 26, "]": 27, "{": 28, "}": 29, "(": 30, ")": 31,
+     "$": 32, ",": 33, ":": 34, "|": 35, "-": 36, ".": 37,
+
+     # Numbers (38-47)
+     "0": 38, "1": 39, "2": 40, "3": 41, "4": 42,
+     "5": 43, "6": 44, "7": 45, "8": 46, "9": 47,
+
+     # Uppercase non-amino acids (48-50)
+     "B": 48, "O": 49, ">": 50,
+
+     # Lowercase letters (51-72)
+     "a": 51, "b": 52, "c": 53, "d": 54, "e": 55,
+     "f": 56, "g": 57, "h": 58, "i": 59, "l": 60,
+     "m": 61, "n": 62, "o": 63, "p": 64, "r": 65,
+     "s": 66, "t": 67, "u": 68, "v": 69, "x": 70,
+     "y": 71, "z": 72,
+
+     # Encoded polymer markers (73-76)
+     "/": 73,   # PEPTIDE
+     "*": 74,   # me
+     "\t": 75,  # am
+     "&": 76,   # ac
+
+     # Miscellaneous (77)
+     "_": 77,
+ }
+
+ # Multi-character to single-character encoding
+ HELM_ENCODE_MAP = {"PEPTIDE": "/", "me": "*", "am": "\t", "ac": "&"}
+ HELM_DECODE_MAP = {v: k for k, v in HELM_ENCODE_MAP.items()}
+
+
+ class HELMBertTokenizer(PreTrainedTokenizer):
+     """Tokenizer for HELM-BERT.
+
+     This tokenizer handles HELM (Hierarchical Editing Language for Macromolecules)
+     notation, converting peptide sequences into token IDs for the HELM-BERT model.
+
+     The tokenizer uses character-level tokenization with special handling for
+     multi-character HELM tokens like "PEPTIDE", "me", "am", "ac".
+
+     Example:
+         >>> from helmbert import HELMBertTokenizer
+         >>> tokenizer = HELMBertTokenizer()
+         >>> inputs = tokenizer("PEPTIDE1{A.C.D.E}$$$$", return_tensors="pt")
+         >>> inputs.input_ids
+         tensor([[ 1, 73, 39, 28,  5, 37,  9, 37,  8, 37, 10, 29, 32, 32, 32, 32,  2]])
+     """
+
+     vocab_files_names = {"vocab_file": "vocab.json"}
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         vocab_file: Optional[str] = None,
+         unk_token: str = "§",
+         sep_token: str = "\n",
+         pad_token: str = " ",
+         cls_token: str = "@",
+         mask_token: str = "¶",
+         bos_token: str = "@",
+         eos_token: str = "\n",
+         model_max_length: int = 512,
+         **kwargs,
+     ):
+         # Load or create vocabulary
+         if vocab_file is not None and os.path.isfile(vocab_file):
+             with open(vocab_file, encoding="utf-8") as f:
+                 self.vocab = json.load(f)
+         else:
+             self.vocab = HELM_VOCAB.copy()
+
+         self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
+
+         # HELM encoding/decoding maps
+         self.encode_map = HELM_ENCODE_MAP.copy()
+         self.decode_map = HELM_DECODE_MAP.copy()
+
+         super().__init__(
+             unk_token=unk_token,
+             sep_token=sep_token,
+             pad_token=pad_token,
+             cls_token=cls_token,
+             mask_token=mask_token,
+             bos_token=bos_token,
+             eos_token=eos_token,
+             model_max_length=model_max_length,
+             **kwargs,
+         )
+
+     @property
+     def vocab_size(self) -> int:
+         """Return the vocabulary size."""
+         return len(self.vocab)
+
+     def get_vocab(self) -> Dict[str, int]:
+         """Return the vocabulary as a dictionary."""
+         return self.vocab.copy()
+
+     def _encode_helm(self, text: str) -> str:
+         """Encode multi-character HELM tokens to single characters.
+
+         Args:
+             text: Raw HELM notation string
+
+         Returns:
+             Encoded string with single-character tokens
+         """
+         if not text:
+             return ""
+         result = text
+         for seq, tok in self.encode_map.items():
+             result = result.replace(seq, tok)
+         return result
+
+     def _decode_helm(self, text: str) -> str:
+         """Decode single-character tokens back to multi-character HELM tokens.
+
+         Args:
+             text: Encoded string with single-character tokens
+
+         Returns:
+             Decoded HELM notation string
+         """
+         if not text:
+             return ""
+         result = text
+         for tok, seq in self.decode_map.items():
+             result = result.replace(tok, seq)
+         return result
+
+     def _tokenize(self, text: str) -> List[str]:
+         """Tokenize a HELM string into a list of tokens.
+
+         Args:
+             text: HELM notation string
+
+         Returns:
+             List of single-character tokens
+         """
+         # First encode multi-character tokens to single characters
+         encoded = self._encode_helm(text)
+         # Return as list of characters
+         return list(encoded)
+
+     def _convert_token_to_id(self, token: str) -> int:
+         """Convert a token to its ID."""
+         return self.vocab.get(token, self.vocab.get(self.unk_token, 3))
+
+     def _convert_id_to_token(self, index: int) -> str:
+         """Convert an ID to its token."""
+         return self.ids_to_tokens.get(index, self.unk_token)
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+         """Convert a list of tokens to a HELM string.
+
+         Args:
+             tokens: List of tokens
+
+         Returns:
+             Decoded HELM notation string
+         """
+         # Join tokens and decode back to HELM notation
+         joined = "".join(tokens)
+         return self._decode_helm(joined)
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """Build model inputs by adding special tokens.
+
+         Args:
+             token_ids_0: First sequence of token IDs
+             token_ids_1: Optional second sequence of token IDs
+
+         Returns:
+             List of token IDs with special tokens added
+         """
+         cls_id = [self.cls_token_id]
+         sep_id = [self.sep_token_id]
+
+         if token_ids_1 is None:
+             return cls_id + token_ids_0 + sep_id
+
+         return cls_id + token_ids_0 + sep_id + token_ids_1 + sep_id
+
+     def get_special_tokens_mask(
+         self,
+         token_ids_0: List[int],
+         token_ids_1: Optional[List[int]] = None,
+         already_has_special_tokens: bool = False,
+     ) -> List[int]:
+         """Get a mask identifying special tokens.
+
+         Args:
+             token_ids_0: First sequence of token IDs
+             token_ids_1: Optional second sequence of token IDs
+             already_has_special_tokens: Whether the sequences already have special tokens
+
+         Returns:
+             List of 0s and 1s (1 = special token)
+         """
+         if already_has_special_tokens:
+             return [1 if x in [self.cls_token_id, self.sep_token_id, self.pad_token_id] else 0 for x in token_ids_0]
+
+         if token_ids_1 is None:
+             return [1] + [0] * len(token_ids_0) + [1]
+
+         return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]
+
+     def create_token_type_ids_from_sequences(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """Create token type IDs for sequence pairs.
+
+         Args:
+             token_ids_0: First sequence of token IDs
+             token_ids_1: Optional second sequence of token IDs
+
+         Returns:
+             List of token type IDs
+         """
+         sep = [self.sep_token_id]
+         cls = [self.cls_token_id]
+
+         if token_ids_1 is None:
+             return [0] * len(cls + token_ids_0 + sep)
+
+         return [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)
+
+     def save_vocabulary(
+         self, save_directory: str, filename_prefix: Optional[str] = None
+     ) -> Tuple[str]:
+         """Save the vocabulary to a file.
+
+         Args:
+             save_directory: Directory to save the vocabulary
+             filename_prefix: Optional prefix for the filename
+
+         Returns:
+             Tuple containing the path to the saved vocabulary file
+         """
+         if not os.path.isdir(save_directory):
+             os.makedirs(save_directory, exist_ok=True)
+
+         vocab_file = os.path.join(
+             save_directory,
+             (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
+         )
+
+         with open(vocab_file, "w", encoding="utf-8") as f:
+             json.dump(self.vocab, f, ensure_ascii=False, indent=2)
+
+         return (vocab_file,)
+
+     @property
+     def mask_token_id(self) -> int:
+         """Return the mask token ID."""
+         return self.vocab.get(self.mask_token, 4)
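
The distinctive step here is the reversible multi-character mapping: "PEPTIDE" → "/", "me" → "*", "am" → "\t", and "ac" → "&" collapse to single characters before character-level splitting and expand again on decode. A minimal round-trip sketch, assuming tokenization_helmbert.py is on the Python path (it has no relative imports, so it loads standalone):

    from tokenization_helmbert import HELMBertTokenizer

    tokenizer = HELMBertTokenizer()

    helm = "PEPTIDE1{A.C.D.E}$$$$"
    tokens = tokenizer._tokenize(helm)  # "PEPTIDE" collapses to "/" before splitting
    assert tokens[0] == "/"

    roundtrip = tokenizer.convert_tokens_to_string(tokens)
    assert roundtrip == helm            # the decode map restores "PEPTIDE"
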
tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "auto_map": {
+     "AutoTokenizer": ["tokenization_helmbert.HELMBertTokenizer", null]
+   },
+   "added_tokens_decoder": {
+     "0": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "@",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "\n",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "§",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "¶",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "@",
+   "clean_up_tokenization_spaces": false,
+   "cls_token": "@",
+   "eos_token": "\n",
+   "extra_special_tokens": {},
+   "mask_token": "¶",
+   "model_max_length": 512,
+   "pad_token": " ",
+   "sep_token": "\n",
+   "tokenizer_class": "HELMBertTokenizer",
+   "unk_token": "§"
+ }
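
With auto_map pointing at tokenization_helmbert.HELMBertTokenizer (the null slot means no fast tokenizer is provided), the tokenizer also loads through the Auto API. A minimal sketch with a placeholder path:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("./helmbert-base", trust_remote_code=True)
    print(repr(tokenizer.mask_token), tokenizer.mask_token_id)  # '¶' 4
    print(tokenizer.model_max_length)  # 512, matching max_position_embeddings
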
vocab.json ADDED
@@ -0,0 +1,80 @@
+ {
+   " ": 0,
+   "@": 1,
+   "\n": 2,
+   "§": 3,
+   "¶": 4,
+   "A": 5,
+   "R": 6,
+   "N": 7,
+   "D": 8,
+   "C": 9,
+   "E": 10,
+   "Q": 11,
+   "G": 12,
+   "H": 13,
+   "I": 14,
+   "L": 15,
+   "K": 16,
+   "M": 17,
+   "F": 18,
+   "P": 19,
+   "S": 20,
+   "T": 21,
+   "W": 22,
+   "Y": 23,
+   "V": 24,
+   "X": 25,
+   "[": 26,
+   "]": 27,
+   "{": 28,
+   "}": 29,
+   "(": 30,
+   ")": 31,
+   "$": 32,
+   ",": 33,
+   ":": 34,
+   "|": 35,
+   "-": 36,
+   ".": 37,
+   "0": 38,
+   "1": 39,
+   "2": 40,
+   "3": 41,
+   "4": 42,
+   "5": 43,
+   "6": 44,
+   "7": 45,
+   "8": 46,
+   "9": 47,
+   "B": 48,
+   "O": 49,
+   ">": 50,
+   "a": 51,
+   "b": 52,
+   "c": 53,
+   "d": 54,
+   "e": 55,
+   "f": 56,
+   "g": 57,
+   "h": 58,
+   "i": 59,
+   "l": 60,
+   "m": 61,
+   "n": 62,
+   "o": 63,
+   "p": 64,
+   "r": 65,
+   "s": 66,
+   "t": 67,
+   "u": 68,
+   "v": 69,
+   "x": 70,
+   "y": 71,
+   "z": 72,
+   "/": 73,
+   "*": 74,
+   "\t": 75,
+   "&": 76,
+   "_": 77
+ }
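
The vocabulary holds exactly 78 entries, matching vocab_size in config.json, and the five special tokens sit at IDs 0 through 4 in the order the model config expects. A small consistency check, assuming the files from this commit are in the working directory:

    import json

    with open("vocab.json", encoding="utf-8") as f:
        vocab = json.load(f)

    assert len(vocab) == 78                      # matches config.json vocab_size
    assert vocab[" "] == 0 and vocab["@"] == 1   # pad_token_id and bos_token_id
    assert vocab["\n"] == 2 and vocab["¶"] == 4  # eos/sep_token_id and mask_token_id
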