Victor1306 committed on
Commit
7d1f690
·
verified ·
1 Parent(s): 99ca0bd

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. model/__init__.py +0 -0
  2. model/model.py +350 -0
model/__init__.py ADDED
File without changes
model/model.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, List
4
+ from dataclasses import dataclass
5
+ from torch.nn.attention.flex_attention import create_block_mask
6
+ from transformers import EsmTokenizer, PretrainedConfig, PreTrainedModel
7
+ from transformers.modeling_outputs import ModelOutput
8
+
9
+ from model.attention import SelfAttention, MultiHeadPAttention
10
+ from model.utils import norm, MLP
11
+
12
+
13
class PLMConfig(PretrainedConfig):
    """Configuration for the PLM protein language model.

    BUG FIX: the original additionally decorated this class with ``@dataclass``.
    With no annotated class-level fields, the decorator left the explicit
    ``__init__`` untouched but injected trivial ``__repr__``/``__eq__`` methods
    (every instance would repr as ``PLMConfig()`` and compare equal), shadowing
    the behavior inherited from ``PretrainedConfig``. The decorator is removed.
    """

    def __init__(
        self,
        hidden_size: int = 512,
        num_attention_heads: int = 8,
        num_hidden_layers: int = 12,
        num_att_tokens: int = 512,
        vocab_size: int = 33,
        expansion_ratio: float = 2.0,
        attention_soft_cap: float = 64.0,
        add_att_soft_cap: bool = True,
        soft_logit_cap: float = 16.0,
        sliding_window_size: int = 2048,
        p_attention: bool = False,
        tie_embeddings: bool = False,
        unet: bool = False,
        mlm: bool = False,
        token_dropout: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size                  # model width
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers      # must be even when unet=True
        self.num_att_tokens = num_att_tokens
        self.vocab_size = vocab_size                    # 33 matches the ESM-2 tokenizer
        self.expansion_ratio = expansion_ratio          # MLP hidden expansion factor
        self.soft_logit_cap = soft_logit_cap            # tanh soft cap for LM-head logits
        self.attention_soft_cap = attention_soft_cap
        self.add_att_soft_cap = add_att_soft_cap
        self.sliding_window_size = sliding_window_size  # bidirectional attention window
        self.p_attention = p_attention                  # use MultiHeadPAttention instead of SelfAttention
        self.tie_embeddings = tie_embeddings            # tie LM-head decoder to input embedding
        self.unet = unet                                # U-Net layer layout with value embeddings
        self.mlm = mlm
        self.token_dropout = token_dropout              # ESM-style mask-token dropout compensation
50
+
51
+
52
@dataclass
class ESMOutput(ModelOutput):
    """Structured model output.

    NOTE(review): declared but never returned in this file — PLM.forward
    returns the bare loss tensor, while the ``__main__`` demo expects this
    output type. Confirm the intended API.
    """
    # Scalar training loss (when labels were provided).
    loss: Optional[torch.Tensor] = None
    # Token-level logits over the vocabulary.
    logits: Optional[torch.Tensor] = None
    # Final hidden states from the transformer.
    last_hidden_state: Optional[torch.Tensor] = None
57
+
58
+
59
class ValueEmbedding(nn.Module):
    """Per-layer token value embeddings with a mirrored (U-Net style) layout.

    Only ``num_hidden_layers // 2`` independent embedding tables exist; the
    second half of the layers reuses them in reverse order, so layer ``i`` and
    layer ``L - 1 - i`` share the same value embedding.
    """

    def __init__(self, config: PLMConfig):
        super().__init__()
        half_depth = config.num_hidden_layers // 2
        self.embed = nn.ModuleList(
            [nn.Embedding(config.vocab_size, config.hidden_size) for _ in range(half_depth)]
        )

    def forward(self, inputs: torch.Tensor) -> List[torch.Tensor]:
        # One embedding lookup per table, then mirror the list for the
        # decoder half of the stack.
        first_half = [table(inputs) for table in self.embed]
        return first_half + first_half[::-1]
71
+
72
+
73
class LMHead(nn.Module):
    """Language-modeling head: norm -> dense -> GELU -> decode to vocab logits,
    with a tanh soft cap keeping logits inside (-soft_logit_cap, soft_logit_cap)."""

    def __init__(self, hidden_size: int, vocab_size: int, soft_logit_cap: float = 30.0):
        super().__init__()
        # Attribute names are part of the checkpoint layout — keep them stable.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        self.soft_logit_cap = soft_logit_cap
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.dense(norm(x)))
        raw_logits = self.decoder(hidden) + self.bias
        # Smoothly bound the logits for numerical stability of the loss.
        cap = self.soft_logit_cap
        return cap * torch.tanh(raw_logits / cap)
87
+
88
+
89
class TransformerBlock(nn.Module):
    """Pre-norm attention + MLP residual block, with optional U-Net input mixing."""

    def __init__(self, config: PLMConfig):
        super().__init__()
        self.config = config
        attn_cls = MultiHeadPAttention if config.p_attention else SelfAttention
        self.attn = attn_cls(config)
        self.mlp = MLP(config)
        self.unet = config.unet
        if config.unet:
            # Learnable mix between the running stream and the initial embedding x0.
            self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        vi: Optional[torch.Tensor] = None,
        x0: Optional[torch.Tensor] = None,
        last_eos: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_inputs = dict(attention_mask=attention_mask, last_eos=last_eos, **kwargs)
        if self.unet:
            x = self.lambdas[0] * x + self.lambdas[1] * x0
            # The value embedding is only forwarded on the U-Net path.
            attn_inputs["vi"] = vi
        x = x + self.attn(x=norm(x), **attn_inputs)
        x = x + self.mlp(norm(x))
        return x
129
+
130
+
131
class Transformer(nn.Module):
    """Plain sequential stack of TransformerBlocks (non-U-Net variant)."""

    def __init__(self, config: PLMConfig):
        super().__init__()
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.num_hidden_layers)
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden = x
        for block in self.layers:
            hidden = block(x=hidden, attention_mask=attention_mask, **kwargs)
        return hidden
149
+
150
+
151
class UnetTransformer(nn.Module):
    """U-Net style transformer: the encoder half records its activations and
    the decoder half adds them back through per-layer learned skip weights."""

    def __init__(self, config: PLMConfig):
        super().__init__()
        assert config.num_hidden_layers % 2 == 0
        half = config.num_hidden_layers // 2
        self.num_encoder_layers = half
        self.num_decoder_layers = half

        # One learned scalar per decoder layer scaling its skip connection.
        self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))

        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        x: torch.Tensor,
        ve: List[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        x0 = x  # initial embedding, re-mixed inside every block
        n_enc = self.num_encoder_layers
        ve_enc, ve_dec = ve[:n_enc], ve[n_enc:]

        skips = []
        for block, value_embed in zip(self.layers[:n_enc], ve_enc):
            x = block(x=x, attention_mask=attention_mask, vi=value_embed, x0=x0, **kwargs)
            skips.append(x)

        for i, (block, value_embed) in enumerate(zip(self.layers[n_enc:], ve_dec)):
            # Decoder layer i pairs with its mirror encoder activation (LIFO).
            x = x + self.skip_weights[i] * skips.pop()
            x = block(x=x, attention_mask=attention_mask, vi=value_embed, x0=x0, **kwargs)
        return x
192
+
193
+
194
class PLM(PreTrainedModel):
    """Protein language model over ESM-2 tokenized sequences.

    Operates on a flat 1-D token stream in which multiple protein documents are
    packed, delimited by CLS/EOS tokens. A flex-attention block mask confines
    attention to a bidirectional sliding window within each document and masks
    out everything after the final EOS (padding).
    """
    config_class = PLMConfig

    def __init__(self, config: PLMConfig):
        super().__init__(config)
        self.config = config
        # NOTE(review): fetches the tokenizer from the HF hub at construction
        # time — requires network access or a populated local cache.
        self.tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
        self.cls_token_id = self.tokenizer.cls_token_id
        self.eos_token_id = self.tokenizer.eos_token_id
        self.pad_token_id = self.tokenizer.pad_token_id
        self.mask_token_id = self.tokenizer.mask_token_id
        self.token_dropout = config.token_dropout

        self.vocab_size = config.vocab_size
        self.n_heads = config.num_attention_heads
        self.sliding_window_size = config.sliding_window_size

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        self.unet = config.unet
        if config.unet:
            self.transformer = UnetTransformer(config)
            self.value_embeds = ValueEmbedding(config)
        else:
            self.transformer = Transformer(config)

        self.lm_head = LMHead(config.hidden_size, config.vocab_size, config.soft_logit_cap)
        if config.tie_embeddings:
            # Weight tying: share the input embedding matrix with the LM head decoder.
            self.lm_head.decoder.weight = self.embedding.weight

        self.mlm = config.mlm
        self.ce = nn.CrossEntropyLoss(ignore_index=-100, reduction='mean')

    def get_last_hidden_state(
        self,
        input_ids: torch.Tensor,
        sliding_window_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Embed a packed 1-D token stream of shape (l,) and run the transformer.

        Returns the (l, hidden_size) hidden states. ``sliding_window_size`` now
        defaults to the configured window (backward-compatible: positional
        callers are unaffected); this also fixes ``get_vector_embeddings``,
        which called this method without the previously required argument.
        """
        if sliding_window_size is None:
            sliding_window_size = self.sliding_window_size

        # Document index per token: increments at every CLS token.
        docs = (input_ids == self.cls_token_id).cumsum(0)
        eos_positions = (input_ids == self.eos_token_id).nonzero()
        if eos_positions.numel() > 0:
            last_eos = eos_positions[-1].squeeze()
        else:
            # If no EOS token found, use the last position of the sequence
            last_eos = len(input_ids) - 1
        seq_len = len(input_ids)

        def doc_mask_mod(b, h, q_idx, kv_idx):
            # Bidirectional sliding window, restricted to the same document,
            # with all positions after the final EOS masked out.
            bidirectional_sliding_window_mask = torch.abs(q_idx - kv_idx) < sliding_window_size
            doc_mask = docs[q_idx] == docs[kv_idx]
            pad_mask = (q_idx <= last_eos) & (kv_idx <= last_eos)
            return bidirectional_sliding_window_mask & doc_mask & pad_mask

        attention_mask = create_block_mask(
            mask_mod=doc_mask_mod,
            B=1,
            H=self.n_heads,
            Q_LEN=seq_len,
            KV_LEN=seq_len,
            device=input_ids.device,
        )

        x = self.embedding(input_ids)

        if self.token_dropout:
            # Zero out mask-token embeddings, then rescale by the observed mask
            # ratio (ESM-style token-dropout compensation).
            x = x.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
            # Guard against division by zero when last_eos == 0.
            real_token_count = max(len(input_ids[:last_eos]), 1)
            mask_ratio_observed = (input_ids == self.mask_token_id).sum().float() / real_token_count
            x = (x * (1 - mask_ratio_observed)).to(x.dtype)

        x = norm(x)
        if self.unet:
            ve = self.value_embeds(input_ids)
            x = self.transformer(
                x=x,
                ve=ve,
                attention_mask=attention_mask,
                last_eos=last_eos,
            )
        else:
            x = self.transformer(
                x=x,
                attention_mask=attention_mask,
                last_eos=last_eos,
            )
        return x

    def get_vector_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Mean-pool the hidden states of each packed document.

        Returns a (num_documents, hidden_size) tensor, one row per document.
        """
        docs = (input_ids == self.cls_token_id).cumsum(0)
        # BUG FIX: the original called get_last_hidden_state(input_ids) while
        # sliding_window_size was a required argument, raising TypeError; the
        # method now defaults to the configured window.
        x = self.get_last_hidden_state(input_ids)
        x = x.view(-1, self.config.hidden_size)  # (S, hidden_size)
        num_docs = docs.max().item()
        doc_ids = docs - 1  # 0-based document labels
        doc_embeds = []
        for doc_idx in range(num_docs):
            mask = (doc_ids == doc_idx)
            # Average all token embeddings belonging to this document.
            doc_embeds.append(x[mask].mean(dim=0))
        return torch.stack(doc_embeds, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        mask_rate: torch.Tensor,
        sliding_window_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Compute the mean cross-entropy loss over ``labels`` (-100 = ignored).

        ``mask_rate`` is currently unused — a previous rescaling of the loss by
        it is disabled — but kept for caller compatibility. Returns the scalar
        loss tensor (not an ESMOutput).
        """
        if sliding_window_size is None:
            sliding_window_size = self.sliding_window_size

        last_hidden_state = self.get_last_hidden_state(input_ids, sliding_window_size)

        # NOTE(review): LMHead norms its input internally, so norm is applied
        # twice here; kept as-is to preserve trained behavior.
        lm_logits = self.lm_head(norm(last_hidden_state))  # (l, v)

        loss = self.ce(
            lm_logits.view(-1, self.vocab_size),
            labels.view(-1).long()
        )

        if torch.isnan(loss):
            # Diagnostic dump so NaN losses are debuggable from training logs.
            torch.set_printoptions(profile="full")
            print("⚠️ NaN loss detected!")
            print("Input IDs:", input_ids.detach().cpu())
            print("Labels:", labels.detach().cpu())
            print("Logits:", lm_logits.detach().cpu())

            labels_cpu = labels.detach().cpu()
            if torch.all(labels_cpu == -100):
                print("⚠️ All labels are -100!")
            else:
                unique_labels = torch.unique(labels_cpu)
                print("Unique labels present:", unique_labels)

        return loss
330
+
331
+
332
if __name__ == "__main__":
    # py -m model.model
    from torchinfo import summary
    config = PLMConfig(
        hidden_size=768,
        num_attention_heads=6,
        num_hidden_layers=24,
        expansion_ratio=8/3,
        unet=True,
    )
    model = PLM(config).cuda()
    summary(model)

    # BUG FIX: the original called model(input_ids) without the required
    # labels/mask_rate arguments (TypeError) and then read output.loss /
    # output.logits, although forward returns the bare loss tensor.
    # The model also consumes a flat 1-D token stream (it uses cumsum(0)
    # and len(input_ids)), not a (batch, seq) matrix.
    input_ids = torch.randint(0, 33, (100,)).cuda()
    labels = input_ids.clone()
    mask_rate = torch.tensor(0.15, device=input_ids.device)
    loss = model(input_ids, labels, mask_rate)
    print(f"loss: {loss}")