supanthadey1 commited on
Commit
b64c0a4
·
verified ·
1 Parent(s): 9b82aa2

Add missing base Bertose source file

Browse files
Files changed (3) hide show
  1. README.md +1 -0
  2. SHA256SUMS +26 -1
  3. src/glycan_bert.py +303 -0
README.md CHANGED
@@ -22,6 +22,7 @@ This repository contains the contrastive Bertose checkpoint used to score ambigu
22
  - `vocab/bpe_vocabulary.json` - WURCS BPE vocabulary.
23
  - `vocab/bpe_ambiguity_tokens.json` - ambiguous BPE token map used by the resolver.
24
  - `src/multimodal_glycan_bert_v3.py` - model definition.
 
25
  - `src/wurcs_bpe_tokenizer.py` - WURCS BPE tokenizer.
26
 
27
  ## Expected Input
 
22
  - `vocab/bpe_vocabulary.json` - WURCS BPE vocabulary.
23
  - `vocab/bpe_ambiguity_tokens.json` - ambiguous BPE token map used by the resolver.
24
  - `src/multimodal_glycan_bert_v3.py` - model definition.
25
+ - `src/glycan_bert.py` - base BERT layers used by the multimodal model.
26
  - `src/wurcs_bpe_tokenizer.py` - WURCS BPE tokenizer.
27
 
28
  ## Expected Input
SHA256SUMS CHANGED
@@ -1,8 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  622368f62c23e97e9137c277eaadcc93ee3901cbb420b591422bb1c2e19689a5 ./.gitattributes
2
- 266caeb2fb9b68076343b40da91116dca0f2302f03cf28c2332b80b1a69c1758 ./README.md
3
  ae468f4e8c06dc0c3848138a474dc43249aa6d14dfd0df8f58d68fcaad371152 ./checkpoints/best_v51_contrastive_model.pt
4
  daf55c190fece0678064e41697a9545592beb1285f8aa74e595b933b9d37b4c2 ./config.json
5
  6a56e6f73b8f874470ecde6e538f3f5029ae23aa6c10559817d1c2a8b59b7c0f ./requirements.txt
 
 
 
 
6
  0d9ce16bf90242f38621d64cd974ea5679bff4c2013bea8d7bffe1b8dd120794 ./src/multimodal_glycan_bert_v3.py
7
  0bc54399362945601bcfd403441fc80968d173200dd0561f57568b2053a94839 ./src/wurcs_bpe_tokenizer.py
8
  c68cd003370b2dcdb162f848f766e4e62f2653c6c38d205f8cbe53a9aabe2d74 ./vocab/bpe_ambiguity_tokens.json
 
1
+ 684888c0ebb17f374298b65ee2807526c066094c701bcc7ebbe1c1095f494fc1 ./.cache/huggingface/.gitignore
2
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./.cache/huggingface/upload/.gitattributes.lock
3
+ 3098e38608a2c2375ac1f78d4c4f52680796f4ff9c0dbaad6b4f0b110fbc7fc3 ./.cache/huggingface/upload/.gitattributes.metadata
4
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./.cache/huggingface/upload/README.md.lock
5
+ ecc75cccadd48cf2cc8d22daec846b6a760f492162ca145c4cfef3536dafcc2a ./.cache/huggingface/upload/README.md.metadata
6
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./.cache/huggingface/upload/SHA256SUMS.lock
7
+ aa2c2e921401dba265bdd190a662861cffd8ff05eaf6ae45a96a25385bd6c5e4 ./.cache/huggingface/upload/SHA256SUMS.metadata
8
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./.cache/huggingface/upload/checkpoints/best_v51_contrastive_model.pt.lock
9
+ 0bc5904fe02b6a64df35829729c29d40f0c0a795d586b10d844fbee91e6fa0e7 ./.cache/huggingface/upload/checkpoints/best_v51_contrastive_model.pt.metadata
10
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./.cache/huggingface/upload/config.json.lock
11
+ 9370200adedd2172ffd8459528e7fd47c5913bf9e791f5b731b0e16121ca3ebf ./.cache/huggingface/upload/config.json.metadata
12
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./.cache/huggingface/upload/requirements.txt.lock
13
+ fef169fb7e8af9c14c21240bb9034cd567bd18dc327ab39423d68ba3b2ee413a ./.cache/huggingface/upload/requirements.txt.metadata
14
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./.cache/huggingface/upload/src/multimodal_glycan_bert_v3.py.lock
15
+ 65dcbe6e66d8bba618e4d22209bd2e83b73b5de767b892c1bbd43db1c9326f42 ./.cache/huggingface/upload/src/multimodal_glycan_bert_v3.py.metadata
16
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./.cache/huggingface/upload/src/wurcs_bpe_tokenizer.py.lock
17
+ 28ca0e31a94c80afc124627b62a574125270a5f269bdff012fd36b465578dc82 ./.cache/huggingface/upload/src/wurcs_bpe_tokenizer.py.metadata
18
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./.cache/huggingface/upload/vocab/bpe_ambiguity_tokens.json.lock
19
+ eb200fe67e613751c0571950e9a7f22f9f44fde0f85b73a40d392189a203f465 ./.cache/huggingface/upload/vocab/bpe_ambiguity_tokens.json.metadata
20
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./.cache/huggingface/upload/vocab/bpe_vocabulary.json.lock
21
+ c00560217b399adfb341aacc38053299c7d4b33b4229e89e68275cd454bb7f5b ./.cache/huggingface/upload/vocab/bpe_vocabulary.json.metadata
22
  622368f62c23e97e9137c277eaadcc93ee3901cbb420b591422bb1c2e19689a5 ./.gitattributes
23
+ 21912ebe4c2b720eac3164c3628f37a39d6c918221c84e04b76a914fd709752d ./README.md
24
  ae468f4e8c06dc0c3848138a474dc43249aa6d14dfd0df8f58d68fcaad371152 ./checkpoints/best_v51_contrastive_model.pt
25
  daf55c190fece0678064e41697a9545592beb1285f8aa74e595b933b9d37b4c2 ./config.json
26
  6a56e6f73b8f874470ecde6e538f3f5029ae23aa6c10559817d1c2a8b59b7c0f ./requirements.txt
27
+ 789fde2ce01f83a5bb363aee29fe33809e2a7015c47c1915655c208d8beec496 ./src/__pycache__/glycan_bert.cpython-312.pyc
28
+ 9a0d7855e244b3a1ff369eba4da5303d528f067d1092fefd5a93c9db164de000 ./src/__pycache__/multimodal_glycan_bert_v3.cpython-312.pyc
29
+ 62259d1fe3d8736e57cadf8ce5a8bf24a7b73368d4d653c2e0d56ac94b94fe76 ./src/__pycache__/wurcs_bpe_tokenizer.cpython-312.pyc
30
+ b69f14c9976951325e3a0a4e8107a16126e67d410e966650f513f1f538a732bb ./src/glycan_bert.py
31
  0d9ce16bf90242f38621d64cd974ea5679bff4c2013bea8d7bffe1b8dd120794 ./src/multimodal_glycan_bert_v3.py
32
  0bc54399362945601bcfd403441fc80968d173200dd0561f57568b2053a94839 ./src/wurcs_bpe_tokenizer.py
33
  c68cd003370b2dcdb162f848f766e4e62f2653c6c38d205f8cbe53a9aabe2d74 ./vocab/bpe_ambiguity_tokens.json
src/glycan_bert.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Glycan BERT Model
3
+
4
+ Transformer-based masked language model for glycan structures.
5
+ Based on BERT/ESM2 architecture adapted for atomic glycan tokenization.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import math
11
+
12
+
13
+ class GlycanBERTConfig:
14
+ """Configuration for GlycanBERT."""
15
+
16
+ def __init__(
17
+ self,
18
+ vocab_size: int = 102,
19
+ hidden_size: int = 384,
20
+ num_hidden_layers: int = 6,
21
+ num_attention_heads: int = 6,
22
+ intermediate_size: int = 1536,
23
+ hidden_dropout_prob: float = 0.1,
24
+ attention_probs_dropout_prob: float = 0.1,
25
+ max_position_embeddings: int = 512,
26
+ layer_norm_eps: float = 1e-12,
27
+ pad_token_id: int = 0,
28
+ mask_token_id: int = 4,
29
+ initializer_range: float = 0.02
30
+ ):
31
+ self.vocab_size = vocab_size
32
+ self.hidden_size = hidden_size
33
+ self.num_hidden_layers = num_hidden_layers
34
+ self.num_attention_heads = num_attention_heads
35
+ self.intermediate_size = intermediate_size
36
+ self.hidden_dropout_prob = hidden_dropout_prob
37
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
38
+ self.max_position_embeddings = max_position_embeddings
39
+ self.layer_norm_eps = layer_norm_eps
40
+ self.pad_token_id = pad_token_id
41
+ self.mask_token_id = mask_token_id
42
+ self.initializer_range = initializer_range
43
+
44
+
45
+ class GlycanBERTEmbeddings(nn.Module):
46
+ """
47
+ Embeddings for glycan tokens including token and positional embeddings.
48
+ """
49
+
50
+ def __init__(self, config: GlycanBERTConfig):
51
+ super().__init__()
52
+ self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
53
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
54
+
55
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
56
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
57
+
58
+ # position_ids (1, max_seq_len) is contiguous in memory and exported when serialized
59
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
60
+
61
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
62
+ """
63
+ Args:
64
+ input_ids: Tensor of shape (batch_size, seq_len)
65
+
66
+ Returns:
67
+ Embeddings of shape (batch_size, seq_len, hidden_size)
68
+ """
69
+ batch_size, seq_len = input_ids.shape
70
+
71
+ # Token embeddings
72
+ token_embeds = self.token_embeddings(input_ids)
73
+
74
+ # Position embeddings
75
+ position_ids = self.position_ids[:, :seq_len]
76
+ position_embeds = self.position_embeddings(position_ids)
77
+
78
+ # Combine
79
+ embeddings = token_embeds + position_embeds
80
+ embeddings = self.LayerNorm(embeddings)
81
+ embeddings = self.dropout(embeddings)
82
+
83
+ return embeddings
84
+
85
+
86
+ class GlycanBERTAttention(nn.Module):
87
+ """Multi-head self-attention."""
88
+
89
+ def __init__(self, config: GlycanBERTConfig):
90
+ super().__init__()
91
+ assert config.hidden_size % config.num_attention_heads == 0
92
+
93
+ self.num_attention_heads = config.num_attention_heads
94
+ self.attention_head_size = config.hidden_size // config.num_attention_heads
95
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
96
+
97
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
98
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
99
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
100
+
101
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
102
+
103
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
104
+ """Reshape for multi-head attention."""
105
+ new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
106
+ x = x.view(*new_shape)
107
+ return x.permute(0, 2, 1, 3) # (batch, heads, seq_len, head_size)
108
+
109
+ def forward(
110
+ self,
111
+ hidden_states: torch.Tensor,
112
+ attention_mask: torch.Tensor = None
113
+ ) -> torch.Tensor:
114
+ """
115
+ Args:
116
+ hidden_states: (batch_size, seq_len, hidden_size)
117
+ attention_mask: (batch_size, seq_len) - 1 for valid, 0 for padding
118
+
119
+ Returns:
120
+ Attention output: (batch_size, seq_len, hidden_size)
121
+ """
122
+ batch_size, seq_len, _ = hidden_states.shape
123
+
124
+ # Linear projections
125
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
126
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
127
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
128
+
129
+ # Attention scores
130
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
131
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
132
+
133
+ # Apply attention mask
134
+ if attention_mask is not None:
135
+ # Convert mask to additive mask
136
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # (batch, 1, 1, seq_len)
137
+ attention_mask = (1.0 - attention_mask) * -10000.0
138
+ attention_scores = attention_scores + attention_mask
139
+
140
+ # Attention probabilities
141
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
142
+ attention_probs = self.dropout(attention_probs)
143
+
144
+ # Apply attention to values
145
+ context_layer = torch.matmul(attention_probs, value_layer)
146
+
147
+ # Reshape back
148
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
149
+ new_shape = context_layer.size()[:-2] + (self.all_head_size,)
150
+ context_layer = context_layer.view(*new_shape)
151
+
152
+ return context_layer
153
+
154
+
155
+ class GlycanBERTLayer(nn.Module):
156
+ """Single transformer layer."""
157
+
158
+ def __init__(self, config: GlycanBERTConfig):
159
+ super().__init__()
160
+ self.attention = GlycanBERTAttention(config)
161
+ self.attention_output = nn.Linear(config.hidden_size, config.hidden_size)
162
+ self.attention_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
163
+
164
+ self.intermediate = nn.Linear(config.hidden_size, config.intermediate_size)
165
+ self.output = nn.Linear(config.intermediate_size, config.hidden_size)
166
+ self.output_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
167
+
168
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
169
+
170
+ def forward(
171
+ self,
172
+ hidden_states: torch.Tensor,
173
+ attention_mask: torch.Tensor = None
174
+ ) -> torch.Tensor:
175
+ """
176
+ Args:
177
+ hidden_states: (batch_size, seq_len, hidden_size)
178
+ attention_mask: (batch_size, seq_len)
179
+
180
+ Returns:
181
+ Output: (batch_size, seq_len, hidden_size)
182
+ """
183
+ # Self-attention
184
+ attention_output = self.attention(hidden_states, attention_mask)
185
+ attention_output = self.attention_output(attention_output)
186
+ attention_output = self.dropout(attention_output)
187
+
188
+ # Add & Norm
189
+ hidden_states = self.attention_layer_norm(hidden_states + attention_output)
190
+
191
+ # Feed-forward
192
+ intermediate_output = self.intermediate(hidden_states)
193
+ intermediate_output = nn.functional.gelu(intermediate_output)
194
+
195
+ layer_output = self.output(intermediate_output)
196
+ layer_output = self.dropout(layer_output)
197
+
198
+ # Add & Norm
199
+ layer_output = self.output_layer_norm(hidden_states + layer_output)
200
+
201
+ return layer_output
202
+
203
+
204
+ class GlycanBERT(nn.Module):
205
+ """
206
+ Glycan BERT model for masked language modeling.
207
+ """
208
+
209
+ def __init__(self, config: GlycanBERTConfig):
210
+ super().__init__()
211
+ self.config = config
212
+
213
+ # Embeddings
214
+ self.embeddings = GlycanBERTEmbeddings(config)
215
+
216
+ # Transformer layers
217
+ self.layers = nn.ModuleList([GlycanBERTLayer(config) for _ in range(config.num_hidden_layers)])
218
+
219
+ # MLM head
220
+ self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size)
221
+
222
+ # Initialize weights
223
+ self.apply(self._init_weights)
224
+
225
+ def _init_weights(self, module):
226
+ """Initialize weights."""
227
+ if isinstance(module, nn.Linear):
228
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
229
+ if module.bias is not None:
230
+ module.bias.data.zero_()
231
+ elif isinstance(module, nn.Embedding):
232
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
233
+ if module.padding_idx is not None:
234
+ module.weight.data[module.padding_idx].zero_()
235
+ elif isinstance(module, nn.LayerNorm):
236
+ module.bias.data.zero_()
237
+ module.weight.data.fill_(1.0)
238
+
239
+ def forward(
240
+ self,
241
+ input_ids: torch.Tensor,
242
+ attention_mask: torch.Tensor = None,
243
+ labels: torch.Tensor = None
244
+ ):
245
+ """
246
+ Args:
247
+ input_ids: (batch_size, seq_len)
248
+ attention_mask: (batch_size, seq_len) - 1 for valid, 0 for padding
249
+ labels: (batch_size, seq_len) - token IDs to predict, -100 for positions to ignore
250
+
251
+ Returns:
252
+ If labels provided: (loss, logits)
253
+ Else: logits
254
+ """
255
+ # Create attention mask if not provided
256
+ if attention_mask is None:
257
+ attention_mask = (input_ids != self.config.pad_token_id).float()
258
+
259
+ # Embeddings
260
+ hidden_states = self.embeddings(input_ids)
261
+
262
+ # Transformer layers
263
+ for layer in self.layers:
264
+ hidden_states = layer(hidden_states, attention_mask)
265
+
266
+ # MLM prediction
267
+ logits = self.mlm_head(hidden_states)
268
+
269
+ # Calculate loss if labels provided
270
+ loss = None
271
+ if labels is not None:
272
+ loss_fct = nn.CrossEntropyLoss() # -100 is ignored
273
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
274
+
275
+ if loss is not None:
276
+ return loss, logits
277
+ return logits
278
+
279
+ def get_embeddings(
280
+ self,
281
+ input_ids: torch.Tensor,
282
+ attention_mask: torch.Tensor = None
283
+ ) -> torch.Tensor:
284
+ """
285
+ Get contextualized embeddings (for downstream tasks).
286
+
287
+ Args:
288
+ input_ids: (batch_size, seq_len)
289
+ attention_mask: (batch_size, seq_len)
290
+
291
+ Returns:
292
+ Embeddings: (batch_size, seq_len, hidden_size)
293
+ """
294
+ if attention_mask is None:
295
+ attention_mask = (input_ids != self.config.pad_token_id).float()
296
+
297
+ hidden_states = self.embeddings(input_ids)
298
+
299
+ for layer in self.layers:
300
+ hidden_states = layer(hidden_states, attention_mask)
301
+
302
+ return hidden_states
303
+