aaronfeller committed on
Commit
eff0d52
·
verified ·
1 Parent(s): 9a99df4

initial commit

Browse files
Files changed (5) hide show
  1. ChemPepMTR.py +154 -0
  2. __init__.py +2 -0
  3. config.json +16 -0
  4. config.py +25 -0
  5. model.safetensors +3 -0
ChemPepMTR.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchtune.modules import RotaryPositionalEmbeddings
5
+ from transformers import PreTrainedModel
6
+ from .config import model_config
7
+ from typing import Mapping
8
+ from transformers.tokenization_utils_base import BatchEncoding
9
+
10
+ class SwiGLU(nn.Module):
11
+ def __init__(self, input_dim, hidden_dim):
12
+ super().__init__()
13
+ self.linear1 = nn.Linear(input_dim, hidden_dim * 2, bias=True)
14
+ self.linear2 = nn.Linear(hidden_dim, input_dim, bias=True)
15
+ self.dropout = nn.Dropout(0.1) # Add dropout for regularization
16
+ def forward(self, x):
17
+ # x: (N, input_dim)
18
+ x1, x2 = self.linear1(x).chunk(2, dim=-1)
19
+ output = self.linear2(F.silu(x1) * x2)
20
+ return self.dropout(output)
21
+
22
+
23
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    RoPE is applied to queries and keys while they are still laid out as
    (batch, seq, heads, head_dim), which is the layout torchtune's
    RotaryPositionalEmbeddings expects. Padding is handled by converting a
    (B, T) 0/1 attention mask into an additive float mask for SDPA.
    """

    def __init__(self, embed_dim, num_heads, max_seq_len):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # Fused projection yields queries, keys, and values in one matmul.
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.rotary = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=max_seq_len)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.dropout = nn.Dropout(0.1)  # regularization on the attention output

    def forward(self, x, input_pos=None, mask=None):
        batch, seq_len, width = x.shape

        # (B, T, C) -> three tensors of shape (B, T, num_heads, head_dim).
        qkv = self.qkv_proj(x).view(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)

        # Rotate queries and keys before moving the head axis forward.
        q = self.rotary(q, input_pos=input_pos)
        k = self.rotary(k, input_pos=input_pos)

        # SDPA expects (batch, heads, seq, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_mask = None
        if mask is not None:
            # 1 -> keep (adds 0.0), 0 -> pad (adds -1e9 to the logits).
            additive = (1.0 - mask.to(dtype=torch.float32)) * -1e9
            # (B, T) -> (B, 1, 1, T) -> broadcast to (B, 1, T, T).
            attn_mask = additive.unsqueeze(1).unsqueeze(2).expand(batch, 1, seq_len, seq_len)

        attended = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_mask)
        attended = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.dropout(self.out_proj(attended))
63
+
64
+
65
class UnifiedTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention then a SwiGLU feed-forward
    network, each wrapped in a residual connection."""

    def __init__(self, embed_dim, num_heads, ffn_hidden_dim, max_seq_len):
        super().__init__()
        self.attn_norm = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, max_seq_len)
        self.ffn_norm = nn.LayerNorm(embed_dim)
        self.ffn = SwiGLU(embed_dim, ffn_hidden_dim)

    def forward(self, x, input_pos=None, mask=None):
        # Residual around pre-normed attention...
        attn_out = self.attn(self.attn_norm(x), input_pos=input_pos, mask=mask)
        x = x + attn_out
        # ...then residual around the pre-normed feed-forward network.
        return x + self.ffn(self.ffn_norm(x))
77
+
78
class TransformerStack(nn.Module):
    """A stack of UnifiedTransformerBlocks followed by a final LayerNorm."""

    def __init__(self, num_blocks, embed_dim, num_heads, ffn_hidden_dim, max_seq_len):
        super().__init__()
        layers = [
            UnifiedTransformerBlock(embed_dim, num_heads, ffn_hidden_dim, max_seq_len)
            for _ in range(num_blocks)
        ]
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, input_pos=None, mask=None):
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden, input_pos=input_pos, mask=mask)
        # Final normalization of the last hidden states.
        return self.norm(hidden)
91
+
92
class MLM_core(nn.Module):
    """Masked-language-model backbone: token embedding -> transformer stack ->
    per-token logits head, plus a pad-aware mean-pooled sequence embedding.

    Args:
        vocab_size: number of token ids the embedding table covers.
        embed_dim: transformer hidden width.
        num_blocks: number of transformer blocks.
        num_heads: attention heads per block.
        ffn_hidden_dim: SwiGLU hidden width.
        output_dim: width of the logits head (normally vocab_size).
        max_seq_len: maximum sequence length for rotary embeddings.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        num_blocks: int,
        num_heads: int,
        ffn_hidden_dim: int,
        output_dim: int,
        max_seq_len: int,
    ):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.transformer = TransformerStack(
            num_blocks, embed_dim, num_heads, ffn_hidden_dim, max_seq_len
        )
        self.sequence_head = nn.Linear(embed_dim, output_dim, bias=True)

    def forward(self, ids, mask=None, pad_token_id=0, input_pos=None):
        """Run the backbone.

        Args:
            ids: (B, T) token ids.
            mask: optional (B, T) 0/1 attention mask forwarded to attention.
            pad_token_id: id treated as padding when mean-pooling.
            input_pos: optional positions forwarded to rotary embeddings.

        Returns:
            dict with 'logits' (B, T, output_dim), 'last_layer' (B, T, C),
            and 'mean_pool' (B, C).
        """
        x = self.embed(ids)
        x = self.transformer(x, mask=mask, input_pos=input_pos)
        logits = self.sequence_head(x)

        # Mean-pool over non-pad positions only.
        # BUG FIX: the previous code zeroed pad positions but still used
        # .mean(dim=1), dividing by the full sequence length T rather than the
        # number of real tokens — biasing pooled vectors toward zero for short
        # sequences. Divide by the actual non-pad token count instead.
        not_pad = (ids != pad_token_id).unsqueeze(-1)          # (B, T, 1)
        summed = (x * not_pad).sum(dim=1)                      # (B, C)
        count = not_pad.sum(dim=1).clamp(min=1)                # avoid div-by-zero on all-pad rows
        mean_pool = summed / count

        outputs = {
            'logits': logits,
            'last_layer': x,
            'mean_pool': mean_pool
        }

        return outputs
129
+
130
class MLM_model(PreTrainedModel):  # HF-facing class name
    """Hugging Face wrapper around MLM_core, configured via model_config."""

    config_class = model_config

    def __init__(self, config):
        super().__init__(config)
        self.model = MLM_core(
            vocab_size=config.vocab_size,
            embed_dim=config.embed_dim,
            num_blocks=config.num_blocks,
            num_heads=config.num_heads,
            ffn_hidden_dim=config.ffn_hidden_dim,
            output_dim=config.output_dim,
            max_seq_len=config.max_seq_len,
        )
        self.post_init()  # Initialize weights and apply final processing

    def forward(self, x=None, **kwargs):
        """Accept a tokenizer output (positional dict/BatchEncoding), explicit
        input_ids/attention_mask kwargs, or a bare tensor of token ids."""
        # Tokenizer output passed positionally.
        if isinstance(x, (BatchEncoding, Mapping)):
            ids, attn = x.get("input_ids"), x.get("attention_mask")
        # Tokenizer output spread as keyword arguments.
        elif "input_ids" in kwargs or "attention_mask" in kwargs:
            ids, attn = kwargs.get("input_ids"), kwargs.get("attention_mask")
        # Bare tensor of token ids, no mask.
        else:
            return self.model(x)
        return self.model(ids, mask=attn)
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .config import model_config
2
+ from .ChemPepMTR import MLM_model
config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "MLM_model",
3
+ "size": "base",
4
+ "ffn_hidden_dim": 1024,
5
+ "embed_dim": 768,
6
+ "num_heads": 12,
7
+ "num_blocks": 24,
8
+ "vocab_size": 405,
9
+ "output_dim": 405,
10
+ "max_seq_len": 2048,
11
+ "auto_map": {
12
+ "AutoConfig": "config.model_config",
13
+ "AutoModel": "ChemPepMTR.MLM_model"
14
+ },
15
+ "architectures": ["MLM_model"]
16
+ }
config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class model_config(PretrainedConfig):
    """Configuration for MLM_model; the defaults describe the "base" size."""

    model_type = "MLM_model"

    def __init__(
        self,
        ffn_hidden_dim = 1024,
        embed_dim = 768,
        num_heads = 12,
        num_blocks = 24,
        vocab_size = 405,
        output_dim = 405,
        max_seq_len = 2048,
        size = "base",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Copy every architecture hyperparameter onto the config instance so
        # Hugging Face serialization picks them up.
        for name, value in (
            ("ffn_hidden_dim", ffn_hidden_dim),
            ("embed_dim", embed_dim),
            ("num_heads", num_heads),
            ("num_blocks", num_blocks),
            ("vocab_size", vocab_size),
            ("output_dim", output_dim),
            ("max_seq_len", max_seq_len),
            ("size", size),
        ):
            setattr(self, name, value)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ee9244b84ff636080bbb7ef874d4a2d5159f70a8518f67a5065ff13fcab54e3
3
+ size 456074420