andrecornman committed on
Commit
2eddd0d
·
verified ·
1 Parent(s): 414cb8c

Upload FlashPPI model

Browse files
__KMP_REGISTERED_LIB_23805 ADDED
Binary file (1.02 kB). View file
 
__KMP_REGISTERED_LIB_91112 ADDED
Binary file (1.02 kB). View file
 
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "FlashPPIModel"
4
+ ],
5
+ "clip_embed_dim": 1024,
6
+ "contact_embed_dim": 1280,
7
+ "contact_num_heads": 8,
8
+ "contact_transformer_depth": 2,
9
+ "dtype": "float32",
10
+ "max_position_embeddings": 512,
11
+ "model_type": "flashppi",
12
+ "plm_depth": 33,
13
+ "plm_dim": 1280,
14
+ "plm_ffn_dim_multiplier": null,
15
+ "plm_heads": 20,
16
+ "plm_norm_eps": 1e-05,
17
+ "plm_swiglu_multiple_of": 256,
18
+ "plm_vocab_size": 37,
19
+ "transformers_version": "4.57.1",
20
+ "use_flash_attention": true,
21
+ "auto_map": {
22
+ "AutoConfig": "configuration_flashppi.FlashPPIConfig",
23
+ "AutoModel": "modeling_flashppi.FlashPPIModel"
24
+ }
25
+ }
configuration_flashppi.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""FlashPPI model configuration."""

from transformers import PretrainedConfig


class FlashPPIConfig(PretrainedConfig):
    """Configuration for the FlashPPI protein-protein interaction model.

    Bundles two groups of hyperparameters: the gLM2 protein language model
    backbone (``plm_*``, defaults match gLM2_650M) and the FlashPPI heads
    (CLIP-style contrastive head and contact-prediction head).
    """

    model_type = "flashppi"

    def __init__(
        self,
        # gLM2 backbone config (defaults match gLM2_650M)
        plm_dim: int = 1280,  # backbone hidden size
        plm_depth: int = 33,  # number of transformer layers
        plm_heads: int = 20,  # attention heads per layer
        plm_vocab_size: int = 37,  # matches the gLM2 tokenizer vocabulary
        plm_norm_eps: float = 1e-5,  # RMSNorm epsilon
        plm_swiglu_multiple_of: int = 256,  # FFN hidden size rounded up to this multiple
        plm_ffn_dim_multiplier: float = None,  # optional extra FFN width factor (None = off)
        # FlashPPI head config
        clip_embed_dim: int = 1024,  # contrastive (CLIP) projection dimension
        contact_embed_dim: int = 1280,  # contact head q/k projection dimension
        contact_num_heads: int = 8,  # contact head attention heads
        contact_transformer_depth: int = 2,  # transformer depth inside the contact head
        max_position_embeddings: int = 512,
        use_flash_attention: bool = True,
        **kwargs
    ):
        super().__init__(**kwargs)
        # gLM2 config
        self.plm_dim = plm_dim
        self.plm_depth = plm_depth
        self.plm_heads = plm_heads
        self.plm_vocab_size = plm_vocab_size
        self.plm_norm_eps = plm_norm_eps
        self.plm_swiglu_multiple_of = plm_swiglu_multiple_of
        self.plm_ffn_dim_multiplier = plm_ffn_dim_multiplier
        # FlashPPI config
        self.clip_embed_dim = clip_embed_dim
        self.contact_embed_dim = contact_embed_dim
        self.contact_num_heads = contact_num_heads
        self.contact_transformer_depth = contact_transformer_depth
        self.max_position_embeddings = max_position_embeddings
        self.use_flash_attention = use_flash_attention
glm_tokenizer.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenizers import Tokenizer
2
+ from tokenizers.models import BPE
3
+ from transformers import PreTrainedTokenizerFast
4
+
5
+
6
class gLM2Tokenizer(PreTrainedTokenizerFast):
    """Character-level tokenizer for gLM2 sequences.

    Builds a fixed 37-token BPE vocabulary (no merges) covering amino acids,
    lowercase nucleotides, and special tokens, and wraps it in a fast
    tokenizer.

    NOTE(review): ``pos_token``/``neg_token`` (<+>/<->) are accepted and
    registered as special tokens on the inner tokenizer, but are not
    forwarded to ``super().__init__`` — confirm they are intentionally not
    exposed as named attributes on the tokenizer.
    """

    # Token order defines the token IDs (index = id); must stay in sync with
    # tokenizer.json.
    VOCAB = [
        "<cls>", "<pad>", "<eos>", "<unk>",
        "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
        "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
        "O", "a", "t", "c", "g", "<+>", "<->", "<mask>", "<sep>",
    ]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        sep_token="<sep>",
        pos_token="<+>",
        neg_token="<->",
        **kwargs,
    ):
        all_tokens = self.VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # BPE with an explicit vocab and no merges behaves as a plain
        # symbol-table tokenizer.
        bpe = BPE(token_to_id, merges=[], unk_token=str(unk_token))
        tokenizer = Tokenizer(bpe)
        special_tokens = [cls_token, pad_token,
                          mask_token, eos_token, sep_token, pos_token, neg_token]

        # Tokens already exist in the vocab; this only marks them special.
        tokenizer.add_special_tokens(
            special_tokens,
        )

        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            sep_token=sep_token,
            **kwargs,
        )
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:783abc99f0d39c350d9be2e553dbf407b9ebc0fa1d288b31f418b4a3ef223f2c
3
+ size 2931379208
modeling_flashppi.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple, Union
6
+ from einops import rearrange, repeat
7
+ from torch.utils.checkpoint import checkpoint
8
+ from transformers import PreTrainedModel
9
+ from transformers.modeling_outputs import ModelOutput
10
+
11
+ from .configuration_flashppi import FlashPPIConfig
12
+
13
# Detect Flash Attention installation.
# When flash-attn is importable, use its fused swiglu / rotary / varlen
# attention / Triton RMSNorm kernels; otherwise fall back to pure-PyTorch
# equivalents defined below and route all code through the padded paths.
try:
    from flash_attn.ops.activations import swiglu
    from flash_attn.layers.rotary import apply_rotary_emb_func
    from flash_attn import flash_attn_varlen_kvpacked_func
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.ops.triton.layer_norm import RMSNorm
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    # Placeholders so module-level names exist; guarded by
    # FLASH_ATTN_AVAILABLE at every call site.
    unpad_input = pad_input = apply_rotary_emb_func = None
    flash_attn_varlen_kvpacked_func = None

    def swiglu(x, y):
        # Pure-PyTorch SwiGLU gate: silu(x) * y.
        return F.silu(x) * y

    class RMSNorm(nn.Module):
        """RMSNorm without variance_epsilon buffer for checkpoint compatibility."""
        def __init__(self, dim, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(dim))
            self.eps = eps

        def forward(self, hidden_states):
            # Normalize in float32 for numerical stability, then cast back
            # to the caller's dtype.
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
            return (self.weight * hidden_states).to(input_dtype)
42
+
43
+
44
@dataclass
class FlashPPIOutput(ModelOutput):
    """Output type for FlashPPI model.

    Args:
        contact_map: (B, L1, L2) contact probabilities between residue pairs.
        contact_score: (B,) maximum contact probability per pair.
        clip_embed1: (B, D) CLIP embedding for first protein.
        clip_embed2: (B, D) CLIP embedding for second protein.
        clip_score: (B,) CLIP similarity score (cosine similarity).
    """
    # Populated by FlashPPIModel.forward; ModelOutput makes fields accessible
    # both as attributes and dict-style keys.
    contact_map: Optional[torch.FloatTensor] = None
    contact_score: Optional[torch.FloatTensor] = None
    clip_embed1: Optional[torch.FloatTensor] = None
    clip_embed2: Optional[torch.FloatTensor] = None
    clip_score: Optional[torch.FloatTensor] = None
60
+
61
+
62
def rotate_half(x, interleaved=False):
    """Rotate the trailing dimension of ``x`` for rotary embeddings.

    Non-interleaved layout treats the last dim as two stacked halves
    (a, b) -> (-b, a); interleaved layout swaps/negates adjacent
    even/odd element pairs while keeping their positions.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        # Pair up (-odd, even) and flatten back into the original layout.
        return torch.stack((-odd, even), dim=-1).flatten(start_dim=-2)
    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
69
+
70
+
71
def apply_rotary_emb_torch(x, cos, sin, interleaved=False, position_ids=None):
    """Apply rotary embeddings using pure PyTorch.

    ``x`` is a padded (B, L, H, D) tensor; ``cos``/``sin`` are cached
    (cache_len, rot_dim/2) tables. Only the leading ``rot_dim`` channels
    are rotated; the remainder passes through unchanged.
    """
    # Select per-position rows: either gather by explicit position_ids or
    # truncate the cache to the sequence length.
    if position_ids is None:
        cos, sin = cos[: x.shape[1]], sin[: x.shape[1]]
    else:
        cos, sin = cos[position_ids], sin[position_ids]

    # Expand the half-width tables to full rotary width, inserting a
    # broadcast axis for the heads dimension.
    if interleaved:
        # (d 2): each frequency duplicated into adjacent slots.
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
    else:
        # (2 d): the whole table duplicated half/half.
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)

    rot_dim = cos.shape[-1]
    head, tail = x[..., :rot_dim], x[..., rot_dim:]

    # Inline 90-degree rotation of the rotary channels.
    if interleaved:
        rotated = torch.stack((-head[..., 1::2], head[..., ::2]), dim=-1).flatten(start_dim=-2)
    else:
        h1, h2 = head.chunk(2, dim=-1)
        rotated = torch.cat((-h2, h1), dim=-1)

    return torch.cat((head * cos + rotated * sin, tail), dim=-1)
92
+
93
+
94
class RotaryEmbedding(nn.Module):
    """Rotary position embeddings with flash attention support.

    Lazily builds and grows cos/sin tables, then applies RoPE either via the
    fused flash-attn kernel (unpadded varlen layout) or via the pure-PyTorch
    fallback (padded layout).
    """

    def __init__(self, dim: int, base: float = 10000.0, interleaved: bool = False, device=None):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.interleaved = interleaved
        # Standard RoPE inverse frequencies: base^(-2i/dim), i in [0, dim/2).
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Lazy cos/sin cache, rebuilt by _update_cos_sin_cache as needed.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Rebuild when the requested length exceeds the cache or the cache
        # lives on a different device.
        # NOTE(review): dtype is not part of the invalidation check, so a
        # dtype change alone reuses the previously-cached tables — confirm
        # this is intended.
        if seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device:
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to q and k; returns the rotated pair.

        With cu_seqlens/max_seqlen (unpadded varlen layout) the flash-attn
        kernel is used in place; otherwise the PyTorch fallback runs on the
        padded layout with optional per-token position_ids.
        """
        # In the varlen path q is (total_tokens, H, D), so the cache length
        # must come from max_seqlen rather than q.shape[1].
        seqlen = q.shape[1] if max_seqlen is None else max_seqlen
        self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)

        if FLASH_ATTN_AVAILABLE and cu_seqlens is not None:
            q = apply_rotary_emb_func(
                q, self._cos_cached, self._sin_cached,
                interleaved=self.interleaved, inplace=True,
                cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
            )
            k = apply_rotary_emb_func(
                k, self._cos_cached, self._sin_cached,
                interleaved=self.interleaved, inplace=True,
                cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
            )
        else:
            q = apply_rotary_emb_torch(q, self._cos_cached, self._sin_cached, self.interleaved, position_ids)
            k = apply_rotary_emb_torch(k, self._cos_cached, self._sin_cached, self.interleaved, position_ids)
        return q, k
142
+
143
+
144
class Attention(nn.Module):
    """Multi-head attention with optional flash attention.

    Supports two input layouts: unpadded varlen (x is (total_tokens, dim)
    with cu_seqlens) using the flash-attn kernel, and padded (B, L, dim)
    using torch scaled_dot_product_attention.
    """

    def __init__(self, dim: int, num_heads: int, use_rope: bool = True):
        super().__init__()
        self.n_heads = num_heads
        self.head_dim = dim // num_heads
        # Fused QKV projection; all projections are bias-free.
        self.wqkv = nn.Linear(dim, num_heads * self.head_dim * 3, bias=False)
        self.wo = nn.Linear(num_heads * self.head_dim, dim, bias=False)
        self.rotary_emb = RotaryEmbedding(self.head_dim) if use_rope else None

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seq_len: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qkv = self.wqkv(x)

        if cu_seqlens is not None and FLASH_ATTN_AVAILABLE:
            # Flash attention path (unpadded): x is (total_tokens, dim).
            total_seqlen = x.shape[0]
            q, k, v = torch.split(qkv, self.n_heads * self.head_dim, dim=-1)
            q = q.view(total_seqlen, self.n_heads, self.head_dim)
            k = k.view(total_seqlen, self.n_heads, self.head_dim)
            v = v.view(total_seqlen, self.n_heads, self.head_dim)

            if self.rotary_emb is not None:
                q, k = self.rotary_emb(q, k, cu_seqlens=cu_seqlens, max_seqlen=max_seq_len)

            # Pack k/v together as expected by the kvpacked varlen kernel.
            kv = torch.stack([k, v], 1)
            output = flash_attn_varlen_kvpacked_func(
                q, kv,
                cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seq_len, max_seqlen_k=max_seq_len,
                dropout_p=0.0, causal=False,
            )
            output = output.view(total_seqlen, self.n_heads * self.head_dim)
        else:
            # SDPA path (padded): x is (B, L, dim).
            bsz, seqlen, _ = x.shape
            q, k, v = torch.split(qkv, self.n_heads * self.head_dim, dim=-1)
            q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
            k = k.view(bsz, seqlen, self.n_heads, self.head_dim)
            v = v.view(bsz, seqlen, self.n_heads, self.head_dim)

            if self.rotary_emb is not None:
                q, k = self.rotary_emb(q, k, position_ids=position_ids)

            # SDPA expects (B, H, L, D).
            q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

            attn_mask = None
            if attention_mask is not None:
                # (B, L) -> (B, 1, 1, L) boolean key-padding mask.
                attn_mask = attention_mask.unsqueeze(1).unsqueeze(2).bool()

            output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.n_heads * self.head_dim)

        return self.wo(output)
205
+
206
+
207
class FeedForward(nn.Module):
    """SwiGLU feedforward network.

    The hidden width starts at ``int(2 * dim * hidden_mult / 3)``, is
    optionally scaled by ``ffn_dim_multiplier``, then rounded up to the
    next multiple of ``multiple_of`` (Llama-style sizing). All projections
    are bias-free.
    """

    def __init__(self, dim: int, hidden_mult: float = 4.0, multiple_of: int = 256, ffn_dim_multiplier: float = None):
        super().__init__()
        width = int(2 * dim * hidden_mult / 3)
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        # Round up to the nearest multiple of `multiple_of`.
        width = ((width + multiple_of - 1) // multiple_of) * multiple_of
        self.w1 = nn.Linear(dim, width, bias=False)
        self.w2 = nn.Linear(width, dim, bias=False)
        self.w3 = nn.Linear(dim, width, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = swiglu(self.w1(x), self.w3(x))
        return self.w2(gated)
222
+
223
+
224
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention -> residual, then
    RMSNorm -> SwiGLU FFN -> residual."""

    def __init__(self, dim: int, num_heads: int, norm_eps: float = 1e-6,
                 multiple_of: int = 256, ffn_dim_multiplier: float = None, use_rope: bool = True):
        super().__init__()
        self.attention = Attention(dim, num_heads, use_rope)
        self.feed_forward = FeedForward(dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier)
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seq_len: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Pre-norm residual wiring: x + Attn(norm(x)), then h + FFN(norm(h)).
        h = x + self.attention(self.attention_norm(x), cu_seqlens, max_seq_len, attention_mask, position_ids)
        return h + self.feed_forward(self.ffn_norm(h))
245
+
246
+
247
class TransformerLayers(nn.Module):
    """Stack of transformer blocks with optional flash attention.

    When flash-attn is available and the batch actually contains padding,
    the input is unpadded once, run through all layers in varlen layout,
    and re-padded at the end; otherwise the padded SDPA path is used with
    per-token position ids derived from the mask.
    """

    def __init__(self, dim: int, num_heads: int, depth: int, norm_eps: float = 1e-6,
                 multiple_of: int = 256, ffn_dim_multiplier: float = None, use_rope: bool = True):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([
            TransformerBlock(dim, num_heads, norm_eps, multiple_of, ffn_dim_multiplier, use_rope)
            for _ in range(depth)
        ])
        # Toggled externally (e.g. by gradient_checkpointing_enable machinery).
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len = x.shape[:2]
        cu_seqlens, max_seq_len_in_batch, indices, position_ids = None, None, None, None

        if FLASH_ATTN_AVAILABLE and attention_mask is not None and not attention_mask.all():
            # Varlen path: strip padding tokens; the mask is then implicit
            # in cu_seqlens, so no per-layer mask is needed.
            x, indices, cu_seqlens, max_seq_len_in_batch, _ = unpad_input(x, attention_mask)
            mask_for_layers = None
        elif attention_mask is not None:
            # Padded path: RoPE positions count only non-pad tokens
            # (clamped so leading pads map to position 0).
            mask_long = attention_mask.long()
            position_ids = (mask_long.cumsum(dim=1) - 1).clamp(min=0)
            mask_for_layers = attention_mask
        else:
            mask_for_layers = None

        for layer in self.layers:
            if self.training and self.gradient_checkpointing:
                x = checkpoint(layer, x, cu_seqlens, max_seq_len_in_batch, mask_for_layers, position_ids, use_reentrant=False)
            else:
                x = layer(x, cu_seqlens, max_seq_len_in_batch, mask_for_layers, position_ids)

        # Restore the original padded (B, L, D) layout if we unpadded above.
        if FLASH_ATTN_AVAILABLE and indices is not None:
            x = pad_input(x, indices, batch_size, seq_len)

        return x
284
+
285
+
286
class GLM2Backbone(nn.Module):
    """gLM2 protein language model backbone: token embedding + transformer
    encoder, returning residue-level hidden states."""

    def __init__(self, config: FlashPPIConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.plm_vocab_size, config.plm_dim)
        self.encoder = TransformerLayers(
            dim=config.plm_dim,
            num_heads=config.plm_heads,
            depth=config.plm_depth,
            norm_eps=config.plm_norm_eps,
            multiple_of=config.plm_swiglu_multiple_of,
            ffn_dim_multiplier=config.plm_ffn_dim_multiplier,
            use_rope=True,
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # (B, L) token ids -> (B, L, plm_dim) residue embeddings.
        h = self.tok_embeddings(input_ids)
        return self.encoder(h, attention_mask)
306
+
307
+
308
class MLPHead(nn.Module):
    """SwiGLU MLP projection head.

    Projects ``in_dim`` features to ``out_dim`` through a gated hidden
    layer of width ``int(in_dim * hidden_mult)``; all linear maps are
    bias-free.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_mult: float = 2.0):
        super().__init__()
        width = int(in_dim * hidden_mult)
        # w1/w2/w3 creation order kept stable for RNG reproducibility.
        self.w1 = nn.Linear(in_dim, width, bias=False)
        self.w2 = nn.Linear(width, out_dim, bias=False)
        self.w3 = nn.Linear(in_dim, width, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
320
+
321
+
322
class ContrastiveHead(nn.Module):
    """CLIP-style contrastive head: masked mean-pool over residues,
    project through a SwiGLU MLP, then L2-normalize."""

    def __init__(self, hidden_dim: int, embed_dim: int):
        super().__init__()
        self.head = MLPHead(hidden_dim, embed_dim)

    def forward(self, residue_embeds: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        keep = mask.unsqueeze(-1).bool()
        # Zero padded residues, then average over valid positions only
        # (clamp guards against a fully-masked sequence).
        summed = residue_embeds.masked_fill(~keep, 0.0).sum(dim=1)
        count = keep.sum(dim=1).float().clamp(min=1.0)
        pooled = summed / count
        return F.normalize(self.head(pooled), dim=-1)
334
+
335
+
336
class ContactHead(nn.Module):
    """Contact prediction head using cross-attention between protein pairs.

    The two proteins are concatenated (with learned segment embeddings),
    jointly refined by a small transformer, and the per-head q·k attention
    logits between protein 1 and protein 2 positions are mixed into a single
    contact logit per residue pair.
    """

    def __init__(self, input_dim: int, contact_dim: int, num_heads: int = 8, depth: int = 2):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = contact_dim // num_heads
        assert contact_dim % num_heads == 0

        # Distinguishes protein 1 (id 0) from protein 2 (id 1).
        self.segment_embed = nn.Embedding(2, input_dim)
        nn.init.normal_(self.segment_embed.weight, std=0.02)

        self.transformer = TransformerLayers(input_dim, num_heads, depth, use_rope=True)
        self.norm = nn.LayerNorm(input_dim)
        self.q_proj = nn.Linear(input_dim, contact_dim, bias=True)
        self.k_proj = nn.Linear(input_dim, contact_dim, bias=True)
        # Mixes per-head logits into one; bias -3.0 biases initial
        # predictions toward "no contact".
        self.output_mix = nn.Linear(num_heads, 1)
        nn.init.constant_(self.output_mix.bias, -3.0)
        self.scale = self.head_dim ** -0.5

    def forward(
        self,
        embed1: torch.Tensor,
        embed2: torch.Tensor,
        mask1: torch.Tensor,
        mask2: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (contact_logits (B, L1, L2), valid_mask (B, L1, L2))."""
        B, L1, D = embed1.shape
        _, L2, _ = embed2.shape

        # Segment embeddings broadcast over the batch dimension.
        seg1 = self.segment_embed(torch.zeros(L1, device=embed1.device, dtype=torch.long))
        seg2 = self.segment_embed(torch.ones(L2, device=embed1.device, dtype=torch.long))

        # Joint refinement over the concatenated pair.
        x = torch.cat([embed1 + seg1.unsqueeze(0), embed2 + seg2.unsqueeze(0)], dim=1)
        combined_mask = torch.cat([mask1, mask2], dim=1).bool() if mask1 is not None and mask2 is not None else None

        x = self.transformer(x, attention_mask=combined_mask)

        # Split back into the two proteins and normalize.
        embed1 = self.norm(x[:, :L1, :])
        embed2 = self.norm(x[:, L1:, :])

        # (B, H, L, head_dim) for per-head scaled dot products.
        q = self.q_proj(embed1).view(B, L1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(embed2).view(B, L2, self.num_heads, self.head_dim).transpose(1, 2)

        attn_logits = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # (B, H, L1, L2) -> (B, L1, L2, H) so the head axis can be mixed.
        attn_logits = attn_logits.permute(0, 2, 3, 1).contiguous()
        contact_logits = self.output_mix(attn_logits).squeeze(-1)

        # A residue pair is valid only when both positions are unmasked.
        if mask1 is not None and mask2 is not None:
            valid_mask = (mask1.unsqueeze(2) * mask2.unsqueeze(1)).bool()
        else:
            valid_mask = torch.ones_like(contact_logits, dtype=torch.bool)

        return contact_logits, valid_mask
390
+
391
+
392
class FlashPPIPreTrainedModel(PreTrainedModel):
    """Base class for FlashPPI models: wires FlashPPIConfig into the
    Transformers loading/saving machinery and defines weight init."""

    config_class = FlashPPIConfig
    base_model_prefix = "flashppi"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        # Standard truncated-normal-style init for linear/embedding layers.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)
        elif isinstance(module, RotaryEmbedding):
            # Re-calculate the frequencies using the module's stored attributes
            # (the inv_freq buffer is non-persistent, so it is not restored
            # from checkpoints and must be recomputed here).
            inv_freq = 1.0 / (
                module.base
                ** (
                    torch.arange(0, module.dim, 2, device=module.inv_freq.device, dtype=torch.float32)
                    / module.dim
                )
            )
            # Force the buffer to update
            with torch.no_grad():
                module.inv_freq.copy_(inv_freq)
418
+
419
class FlashPPIModel(FlashPPIPreTrainedModel):
    """FlashPPI model: gLM2 backbone plus CLIP-style contrastive heads and a
    residue-level contact-prediction head for protein pairs."""

    def __init__(self, config: FlashPPIConfig):
        super().__init__(config)
        self.config = config

        # gLM2 backbone
        self.plm = GLM2Backbone(config)

        # CLIP heads (asymmetric for query/key)
        self.head_q = ContrastiveHead(config.plm_dim, config.clip_embed_dim)
        self.head_k = ContrastiveHead(config.plm_dim, config.clip_embed_dim)
        # Learnable contrastive temperature, initialized at ln(1/0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.6593)  # ln(1/0.07)

        # Contact prediction head
        self.contact_head = ContactHead(
            config.plm_dim,
            config.contact_embed_dim,
            num_heads=config.contact_num_heads,
            depth=config.contact_transformer_depth,
        )

        self.post_init()

    def encode_protein(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a protein sequence to residue-level embeddings.

        Args:
            input_ids: (B, L) token IDs from gLM2 tokenizer.
            attention_mask: (B, L) attention mask.

        Returns:
            (B, L, plm_dim) residue embeddings.
        """
        return self.plm(input_ids, attention_mask)

    def predict_contacts(
        self,
        embed1: torch.Tensor,
        embed2: torch.Tensor,
        mask1: torch.Tensor,
        mask2: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict contact map from pre-computed residue embeddings.

        This method is useful for efficient 2-stage inference where embeddings
        are pre-computed and cached.

        Args:
            embed1: (B, L1, D) residue embeddings for protein 1.
            embed2: (B, L2, D) residue embeddings for protein 2.
            mask1: (B, L1) attention mask for protein 1.
            mask2: (B, L2) attention mask for protein 2.

        Returns:
            contact_logits: (B, L1, L2) raw logits.
            valid_mask: (B, L1, L2) mask for valid positions.
        """
        return self.contact_head(embed1, embed2, mask1, mask2)

    def forward(
        self,
        input_ids1: torch.Tensor,
        input_ids2: torch.Tensor,
        attention_mask1: Optional[torch.Tensor] = None,
        attention_mask2: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[Tuple, FlashPPIOutput]:
        """Forward pass for protein pair interaction prediction.

        Args:
            input_ids1: (B, L1) token IDs for protein 1.
            input_ids2: (B, L2) token IDs for protein 2.
            attention_mask1: (B, L1) attention mask for protein 1.
            attention_mask2: (B, L2) attention mask for protein 2.
            return_dict: Whether to return a FlashPPIOutput or tuple.

        Returns:
            FlashPPIOutput with contact predictions and CLIP embeddings.
        """
        B = input_ids1.shape[0]
        L1, L2 = input_ids1.shape[1], input_ids2.shape[1]

        if attention_mask1 is None:
            attention_mask1 = torch.ones_like(input_ids1)
        if attention_mask2 is None:
            attention_mask2 = torch.ones_like(input_ids2)

        # Encode both proteins in a single batched PLM call for efficiency
        # Pad to same length if needed
        if L1 != L2:
            max_len = max(L1, L2)
            if L1 < max_len:
                pad_len = max_len - L1
                # Pad token id 0; padded positions are masked out below.
                input_ids1 = F.pad(input_ids1, (0, pad_len), value=0)
                attention_mask1 = F.pad(attention_mask1, (0, pad_len), value=0)
            if L2 < max_len:
                pad_len = max_len - L2
                input_ids2 = F.pad(input_ids2, (0, pad_len), value=0)
                attention_mask2 = F.pad(attention_mask2, (0, pad_len), value=0)

        # Batch both sequences for single PLM forward pass
        batched_input_ids = torch.cat([input_ids1, input_ids2], dim=0)
        batched_attention_mask = torch.cat([attention_mask1, attention_mask2], dim=0)
        batched_embeds = self.encode_protein(batched_input_ids, batched_attention_mask)

        # Split and trim back to original lengths
        residue_embeds1 = batched_embeds[:B, :L1, :]
        residue_embeds2 = batched_embeds[B:, :L2, :]
        attention_mask1 = attention_mask1[:, :L1]
        attention_mask2 = attention_mask2[:, :L2]

        # Contrastive embeddings (both heads L2-normalize, so the dot
        # product below is a cosine similarity).
        clip_embed1 = self.head_q(residue_embeds1, attention_mask1)
        clip_embed2 = self.head_k(residue_embeds2, attention_mask2)
        clip_score = (clip_embed1 * clip_embed2).sum(dim=-1)

        # Contact prediction
        contact_logits, valid_mask = self.contact_head(
            residue_embeds1, residue_embeds2, attention_mask1, attention_mask2
        )
        contact_map = torch.sigmoid(contact_logits)

        # Mask invalid positions before taking max
        contact_map_masked = contact_map.masked_fill(~valid_mask, 0.0)
        contact_score = contact_map_masked.flatten(1).max(dim=-1).values

        if not return_dict:
            return (contact_map, contact_score, clip_embed1, clip_embed2, clip_score)

        return FlashPPIOutput(
            contact_map=contact_map,
            contact_score=contact_score,
            clip_embed1=clip_embed1,
            clip_embed2=clip_embed2,
            clip_score=clip_score,
        )
special_tokens_map.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": {
3
+ "content": "<cls>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<eos>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "mask_token": {
17
+ "content": "<mask>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "pad_token": {
24
+ "content": "<pad>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "sep_token": {
31
+ "content": "<sep>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "unk_token": {
38
+ "content": "<unk>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ }
44
+ }
tokenizer.json ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "content": "<cls>",
9
+ "single_word": false,
10
+ "lstrip": false,
11
+ "rstrip": false,
12
+ "normalized": false,
13
+ "special": true
14
+ },
15
+ {
16
+ "id": 1,
17
+ "content": "<pad>",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false,
22
+ "special": true
23
+ },
24
+ {
25
+ "id": 2,
26
+ "content": "<eos>",
27
+ "single_word": false,
28
+ "lstrip": false,
29
+ "rstrip": false,
30
+ "normalized": false,
31
+ "special": true
32
+ },
33
+ {
34
+ "id": 3,
35
+ "content": "<unk>",
36
+ "single_word": false,
37
+ "lstrip": false,
38
+ "rstrip": false,
39
+ "normalized": false,
40
+ "special": true
41
+ },
42
+ {
43
+ "id": 33,
44
+ "content": "<+>",
45
+ "single_word": false,
46
+ "lstrip": false,
47
+ "rstrip": false,
48
+ "normalized": false,
49
+ "special": true
50
+ },
51
+ {
52
+ "id": 34,
53
+ "content": "<->",
54
+ "single_word": false,
55
+ "lstrip": false,
56
+ "rstrip": false,
57
+ "normalized": false,
58
+ "special": true
59
+ },
60
+ {
61
+ "id": 35,
62
+ "content": "<mask>",
63
+ "single_word": false,
64
+ "lstrip": false,
65
+ "rstrip": false,
66
+ "normalized": false,
67
+ "special": true
68
+ },
69
+ {
70
+ "id": 36,
71
+ "content": "<sep>",
72
+ "single_word": false,
73
+ "lstrip": false,
74
+ "rstrip": false,
75
+ "normalized": false,
76
+ "special": true
77
+ }
78
+ ],
79
+ "normalizer": null,
80
+ "pre_tokenizer": null,
81
+ "post_processor": null,
82
+ "decoder": null,
83
+ "model": {
84
+ "type": "BPE",
85
+ "dropout": null,
86
+ "unk_token": "<unk>",
87
+ "continuing_subword_prefix": null,
88
+ "end_of_word_suffix": null,
89
+ "fuse_unk": false,
90
+ "byte_fallback": false,
91
+ "ignore_merges": false,
92
+ "vocab": {
93
+ "<cls>": 0,
94
+ "<pad>": 1,
95
+ "<eos>": 2,
96
+ "<unk>": 3,
97
+ "L": 4,
98
+ "A": 5,
99
+ "G": 6,
100
+ "V": 7,
101
+ "S": 8,
102
+ "E": 9,
103
+ "R": 10,
104
+ "T": 11,
105
+ "I": 12,
106
+ "D": 13,
107
+ "P": 14,
108
+ "K": 15,
109
+ "Q": 16,
110
+ "N": 17,
111
+ "F": 18,
112
+ "Y": 19,
113
+ "M": 20,
114
+ "H": 21,
115
+ "W": 22,
116
+ "C": 23,
117
+ "X": 24,
118
+ "B": 25,
119
+ "U": 26,
120
+ "Z": 27,
121
+ "O": 28,
122
+ "a": 29,
123
+ "t": 30,
124
+ "c": 31,
125
+ "g": 32,
126
+ "<+>": 33,
127
+ "<->": 34,
128
+ "<mask>": 35,
129
+ "<sep>": 36
130
+ },
131
+ "merges": []
132
+ }
133
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<cls>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<eos>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "33": {
36
+ "content": "<+>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "34": {
44
+ "content": "<->",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "35": {
52
+ "content": "<mask>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "36": {
60
+ "content": "<sep>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ }
67
+ },
68
+ "auto_map": {
69
+ "AutoTokenizer": [
70
+ "glm_tokenizer.gLM2Tokenizer",
71
+ null
72
+ ]
73
+ },
74
+ "clean_up_tokenization_spaces": true,
75
+ "cls_token": "<cls>",
76
+ "eos_token": "<eos>",
77
+ "extra_special_tokens": {},
78
+ "mask_token": "<mask>",
79
+ "model_max_length": 1000000000000000019884624838656,
80
+ "pad_token": "<pad>",
81
+ "sep_token": "<sep>",
82
+ "tokenizer_class": "gLM2Tokenizer",
83
+ "unk_token": "<unk>"
84
+ }