| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass |
| from typing import Optional, Tuple, Union |
| from einops import rearrange, repeat |
| from torch.utils.checkpoint import checkpoint |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
|
|
| from .configuration_flashppi import FlashPPIConfig |
|
|
| |
# Optional flash-attn acceleration. These fused CUDA kernels are only present
# when the flash-attn package is installed (CUDA builds); we degrade
# gracefully to the pure-PyTorch fallbacks defined below. Call sites branch
# on FLASH_ATTN_AVAILABLE before touching any of these names.
try:
    from flash_attn.layers.rotary import apply_rotary_emb_func
    from flash_attn import flash_attn_varlen_kvpacked_func
    from flash_attn.bert_padding import pad_input, unpad_input
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    # Keep the module-level names defined so attribute references and type
    # checkers do not fail when flash-attn is absent.
    unpad_input = pad_input = apply_rotary_emb_func = None
    flash_attn_varlen_kvpacked_func = None
|
|
def swiglu(x, y):
    """SwiGLU gate: SiLU-activate ``x`` and scale ``y`` by it elementwise."""
    return y * F.silu(x)
|
|
class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm (no mean subtraction, no bias).

    The epsilon is kept as a plain Python attribute rather than a registered
    buffer so checkpoints saved without it still load cleanly.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, hidden_states):
        # Normalize in fp32 for numerical stability, then cast back to the
        # caller's dtype.
        orig_dtype = hidden_states.dtype
        x = hidden_states.float()
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (self.weight * (x * inv_rms)).to(orig_dtype)
| |
@dataclass
class FlashPPIOutput(ModelOutput):
    """Output type for FlashPPI model.

    All fields default to None, following the Hugging Face ``ModelOutput``
    convention (absent heads simply leave their fields unset).

    Args:
        contact_map: (B, L1, L2) contact probabilities between residue pairs.
        contact_score: (B,) maximum contact probability per pair.
        clip_embed1: (B, D) CLIP embedding for first protein.
        clip_embed2: (B, D) CLIP embedding for second protein.
        clip_score: (B,) CLIP similarity score (cosine similarity).
    """
    contact_map: Optional[torch.FloatTensor] = None
    contact_score: Optional[torch.FloatTensor] = None
    clip_embed1: Optional[torch.FloatTensor] = None
    clip_embed2: Optional[torch.FloatTensor] = None
    clip_score: Optional[torch.FloatTensor] = None
|
|
|
|
def rotate_half(x, interleaved=False):
    """Rotate the last dimension of ``x`` for rotary embeddings.

    Non-interleaved (GPT-NeoX layout): halves (x1, x2) -> (-x2, x1).
    Interleaved (GPT-J layout): even/odd lanes -> pairs (-odd, even),
    flattened back to the original last-dim size.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        return torch.stack((-odd, even), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
|
|
|
|
def apply_rotary_emb_torch(x, cos, sin, interleaved=False, position_ids=None):
    """Apply rotary position embeddings using pure PyTorch ops.

    Args:
        x: (..., seqlen, nheads, headdim) tensor; only the leading rotary
            channels of the head dim are rotated, the rest pass through.
        cos, sin: (max_seqlen, rotary_dim // 2) lookup tables.
        interleaved: GPT-J-style lane-interleaved layout if True, otherwise
            the two-half GPT-NeoX layout.
        position_ids: optional per-token positions used to gather cos/sin;
            defaults to positions 0..seqlen-1.
    """
    if position_ids is None:
        cos = cos[:x.shape[1]]
        sin = sin[:x.shape[1]]
    else:
        cos = cos[position_ids]
        sin = sin[position_ids]

    # Expand the half-dim tables to the full rotary dim and insert a broadcast
    # axis over heads: block-duplicated "(2 d)" for the two-half layout,
    # lane-duplicated "(d 2)" for the interleaved layout.
    if interleaved:
        cos = torch.repeat_interleave(cos, 2, dim=-1).unsqueeze(-2)
        sin = torch.repeat_interleave(sin, 2, dim=-1).unsqueeze(-2)
    else:
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)

    ro_dim = cos.shape[-1]
    rotated = x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin
    return torch.cat((rotated, x[..., ro_dim:]), dim=-1)
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embeddings with flash attention support.

    Cos/sin tables are cached and regenerated whenever the requested sequence
    length grows or the device/dtype of the incoming tensors changes. The
    dtype check fixes a stale-cache bug: without it, a cache built during an
    fp32 pass would be handed to later fp16/bf16 passes unchanged.
    """

    def __init__(self, dim: int, base: float = 10000.0, interleaved: bool = False, device=None):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.interleaved = interleaved
        # Standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        # Non-persistent: recomputed rather than stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rebuild the cos/sin cache when it is missing or stale.

        The cache is stale if it is shorter than ``seqlen``, lives on a
        different device, or holds a different dtype than requested.
        """
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype  # fix: refresh on dtype change
        ):
            self._seq_len_cached = seqlen
            # Compute the angle table in fp32 for accuracy, cast afterwards.
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to q and k.

        Uses the fused flash-attn kernel when available and the input is in
        unpadded varlen form (cu_seqlens given); otherwise falls back to the
        pure-PyTorch implementation.
        """
        seqlen = q.shape[1] if max_seqlen is None else max_seqlen
        self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)

        if FLASH_ATTN_AVAILABLE and cu_seqlens is not None:
            # Fused kernel; operates in place on the unpadded tensors.
            q = apply_rotary_emb_func(
                q, self._cos_cached, self._sin_cached,
                interleaved=self.interleaved, inplace=True,
                cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
            )
            k = apply_rotary_emb_func(
                k, self._cos_cached, self._sin_cached,
                interleaved=self.interleaved, inplace=True,
                cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
            )
        else:
            q = apply_rotary_emb_torch(q, self._cos_cached, self._sin_cached, self.interleaved, position_ids)
            k = apply_rotary_emb_torch(k, self._cos_cached, self._sin_cached, self.interleaved, position_ids)
        return q, k
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with optional rotary embeddings.

    Runs the fused flash-attn varlen kernel when the input arrives unpadded
    (cu_seqlens given) and flash-attn is installed; otherwise falls back to
    torch's scaled_dot_product_attention on padded batches.
    """

    def __init__(self, dim: int, num_heads: int, use_rope: bool = True):
        super().__init__()
        self.n_heads = num_heads
        self.head_dim = dim // num_heads
        inner_dim = num_heads * self.head_dim
        # Fused QKV projection, no biases (LLaMA/gLM2 convention in this file).
        self.wqkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        self.wo = nn.Linear(inner_dim, dim, bias=False)
        self.rotary_emb = RotaryEmbedding(self.head_dim) if use_rope else None

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seq_len: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        fused = self.wqkv(x)
        inner_dim = self.n_heads * self.head_dim
        queries, keys, values = torch.split(fused, inner_dim, dim=-1)

        if cu_seqlens is not None and FLASH_ATTN_AVAILABLE:
            # Unpadded path: x is (total_tokens, dim) with sequence boundaries
            # described by cu_seqlens.
            total = x.shape[0]
            shape = (total, self.n_heads, self.head_dim)
            q, k, v = (t.view(shape) for t in (queries, keys, values))

            if self.rotary_emb is not None:
                q, k = self.rotary_emb(q, k, cu_seqlens=cu_seqlens, max_seqlen=max_seq_len)

            # Kernel expects K/V packed along dim 1: (total, 2, heads, head_dim).
            packed_kv = torch.stack((k, v), dim=1)
            attn_out = flash_attn_varlen_kvpacked_func(
                q, packed_kv,
                cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seq_len, max_seqlen_k=max_seq_len,
                dropout_p=0.0, causal=False,
            )
            attn_out = attn_out.view(total, inner_dim)
        else:
            # Padded path: x is (batch, seqlen, dim).
            bsz, seqlen, _ = x.shape
            shape = (bsz, seqlen, self.n_heads, self.head_dim)
            q, k, v = (t.view(shape) for t in (queries, keys, values))

            if self.rotary_emb is not None:
                q, k = self.rotary_emb(q, k, position_ids=position_ids)

            # SDPA expects (batch, heads, seqlen, head_dim).
            q, k, v = (t.transpose(1, 2) for t in (q, k, v))

            # (B, L) key-padding mask broadcast to (B, 1, 1, L); True = attend.
            sdpa_mask = (
                attention_mask.unsqueeze(1).unsqueeze(2).bool()
                if attention_mask is not None
                else None
            )

            attn_out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=sdpa_mask, dropout_p=0.0, is_causal=False
            )
            attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seqlen, inner_dim)

        return self.wo(attn_out)
|
|
|
|
class FeedForward(nn.Module):
    """SwiGLU feed-forward network with LLaMA-style hidden sizing."""

    def __init__(self, dim: int, hidden_mult: float = 4.0, multiple_of: int = 256, ffn_dim_multiplier: float = None):
        super().__init__()
        # 2/3 of the classic hidden_mult*dim (to keep parameter count roughly
        # constant with the extra gate projection), optionally rescaled, then
        # rounded up to a multiple of `multiple_of` for hardware efficiency.
        hidden = int(2 * dim * hidden_mult / 3)
        if ffn_dim_multiplier is not None:
            hidden = int(ffn_dim_multiplier * hidden)
        hidden = -(-hidden // multiple_of) * multiple_of  # ceil to multiple
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``w2(silu(w1(x)) * w3(x))``."""
        gate = self.w1(x)
        value = self.w3(x)
        return self.w2(swiglu(gate, value))
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and SwiGLU FFN, each residual."""

    def __init__(self, dim: int, num_heads: int, norm_eps: float = 1e-6,
                 multiple_of: int = 256, ffn_dim_multiplier: float = None, use_rope: bool = True):
        super().__init__()
        self.attention = Attention(dim, num_heads, use_rope)
        self.feed_forward = FeedForward(dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier)
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seq_len: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-layer with residual connection.
        attn_out = self.attention(
            self.attention_norm(x), cu_seqlens, max_seq_len, attention_mask, position_ids
        )
        x = x + attn_out
        # Feed-forward sub-layer with residual connection.
        return x + self.feed_forward(self.ffn_norm(x))
|
|
|
|
class TransformerLayers(nn.Module):
    """Stack of transformer blocks with optional flash attention.

    When flash-attn is installed and the batch actually contains padding, the
    input is unpadded once here (and re-padded after the last layer) so every
    block runs on the packed varlen layout; otherwise blocks fall back to
    masked SDPA on the padded batch.
    """

    def __init__(self, dim: int, num_heads: int, depth: int, norm_eps: float = 1e-6,
                 multiple_of: int = 256, ffn_dim_multiplier: float = None, use_rope: bool = True):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([
            TransformerBlock(dim, num_heads, norm_eps, multiple_of, ffn_dim_multiplier, use_rope)
            for _ in range(depth)
        ])
        # Toggled externally (e.g. via the owning PreTrainedModel's
        # gradient-checkpointing machinery); only takes effect in training mode.
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run all blocks over ``x``.

        Args:
            x: (B, L, dim) hidden states.
            attention_mask: optional (B, L) mask, 1 = real token, 0 = padding.

        Returns:
            (B, L, dim) hidden states in the original padded layout.
        """
        batch_size, seq_len = x.shape[:2]
        cu_seqlens, max_seq_len_in_batch, indices, position_ids = None, None, None, None

        if FLASH_ATTN_AVAILABLE and attention_mask is not None and not attention_mask.all():
            # Varlen path: strip padding; layers see packed (total_tokens, dim)
            # and carry the boundaries in cu_seqlens instead of a mask.
            x, indices, cu_seqlens, max_seq_len_in_batch, _ = unpad_input(x, attention_mask)
            mask_for_layers = None
        elif attention_mask is not None:
            # Padded SDPA path: derive per-token positions that count only
            # non-pad tokens (consumed by the rotary-embedding fallback).
            mask_long = attention_mask.long()
            position_ids = (mask_long.cumsum(dim=1) - 1).clamp(min=0)
            mask_for_layers = attention_mask
        else:
            mask_for_layers = None

        for layer in self.layers:
            if self.training and self.gradient_checkpointing:
                # Non-reentrant checkpointing recomputes activations backward.
                x = checkpoint(layer, x, cu_seqlens, max_seq_len_in_batch, mask_for_layers, position_ids, use_reentrant=False)
            else:
                x = layer(x, cu_seqlens, max_seq_len_in_batch, mask_for_layers, position_ids)

        if FLASH_ATTN_AVAILABLE and indices is not None:
            # Restore the original (B, L, dim) padded layout.
            x = pad_input(x, indices, batch_size, seq_len)

        return x
|
|
|
|
class GLM2Backbone(nn.Module):
    """gLM2 protein language model backbone: token embedding + encoder stack."""

    def __init__(self, config: FlashPPIConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.plm_vocab_size, config.plm_dim)
        encoder_kwargs = dict(
            dim=config.plm_dim,
            num_heads=config.plm_heads,
            depth=config.plm_depth,
            norm_eps=config.plm_norm_eps,
            multiple_of=config.plm_swiglu_multiple_of,
            ffn_dim_multiplier=config.plm_ffn_dim_multiplier,
            use_rope=True,
        )
        self.encoder = TransformerLayers(**encoder_kwargs)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed token IDs and run them through the transformer encoder.

        Returns (B, L, plm_dim) residue-level hidden states.
        """
        hidden = self.tok_embeddings(input_ids)
        return self.encoder(hidden, attention_mask)
|
|
|
|
class MLPHead(nn.Module):
    """Two-branch SwiGLU projection head (gate * value, then project down)."""

    def __init__(self, in_dim: int, out_dim: int, hidden_mult: float = 2.0):
        super().__init__()
        hidden = int(hidden_mult * in_dim)
        self.w1 = nn.Linear(in_dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, out_dim, bias=False)
        self.w3 = nn.Linear(in_dim, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w1(x))
        return self.w2(gated * self.w3(x))
|
|
|
|
class ContrastiveHead(nn.Module):
    """CLIP-style contrastive head: masked mean pooling, MLP, L2-normalize."""

    def __init__(self, hidden_dim: int, embed_dim: int):
        super().__init__()
        self.head = MLPHead(hidden_dim, embed_dim)

    def forward(self, residue_embeds: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Zero out padded residues before summing, divide by the count of real
        # residues (clamped so an all-pad sequence cannot divide by zero).
        valid = mask.bool().unsqueeze(-1)
        pooled = residue_embeds.masked_fill(~valid, 0.0).sum(dim=1)
        counts = valid.sum(dim=1).float().clamp(min=1.0)
        return F.normalize(self.head(pooled / counts), dim=-1)
|
|
|
|
class ContactHead(nn.Module):
    """Contact prediction head using cross-attention between protein pairs.

    The two residue-embedding sets are tagged with segment embeddings, jointly
    refined by a small transformer, then scored pairwise with a multi-head
    Q·K similarity whose heads are mixed down to a single logit per pair.
    """

    def __init__(self, input_dim: int, contact_dim: int, num_heads: int = 8, depth: int = 2):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = contact_dim // num_heads
        assert contact_dim % num_heads == 0

        # Segment embeddings distinguish protein 1 (id 0) from protein 2 (id 1).
        self.segment_embed = nn.Embedding(2, input_dim)
        nn.init.normal_(self.segment_embed.weight, std=0.02)

        self.transformer = TransformerLayers(input_dim, num_heads, depth, use_rope=True)
        self.norm = nn.LayerNorm(input_dim)
        self.q_proj = nn.Linear(input_dim, contact_dim, bias=True)
        self.k_proj = nn.Linear(input_dim, contact_dim, bias=True)
        # Mixes the per-head similarity maps into one logit; the -3 bias starts
        # predictions near sigmoid(-3) ~ 0.05, a sparse-contact prior.
        self.output_mix = nn.Linear(num_heads, 1)
        nn.init.constant_(self.output_mix.bias, -3.0)
        self.scale = self.head_dim ** -0.5

    def forward(
        self,
        embed1: torch.Tensor,
        embed2: torch.Tensor,
        mask1: torch.Tensor,
        mask2: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score residue-residue contacts.

        Args:
            embed1: (B, L1, D) residue embeddings for protein 1.
            embed2: (B, L2, D) residue embeddings for protein 2.
            mask1: (B, L1) mask for protein 1, 1 = real residue.
            mask2: (B, L2) mask for protein 2.

        Returns:
            contact_logits: (B, L1, L2) raw pairwise logits.
            valid_mask: (B, L1, L2) True where both residues are real.
        """
        bsz, len1, _ = embed1.shape
        len2 = embed2.shape[1]
        device = embed1.device

        seg1 = self.segment_embed(torch.zeros(len1, device=device, dtype=torch.long))
        seg2 = self.segment_embed(torch.ones(len2, device=device, dtype=torch.long))

        # Concatenate both proteins into one sequence for joint refinement.
        tokens = torch.cat([embed1 + seg1.unsqueeze(0), embed2 + seg2.unsqueeze(0)], dim=1)
        if mask1 is not None and mask2 is not None:
            joint_mask = torch.cat([mask1, mask2], dim=1).bool()
        else:
            joint_mask = None

        tokens = self.transformer(tokens, attention_mask=joint_mask)

        # Split back into the two proteins and normalize.
        h1 = self.norm(tokens[:, :len1, :])
        h2 = self.norm(tokens[:, len1:, :])

        # (B, H, L, head_dim) query/key tensors.
        q = self.q_proj(h1).view(bsz, len1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(h2).view(bsz, len2, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled pairwise similarities, heads moved last for the mixing layer.
        sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        sim = sim.permute(0, 2, 3, 1).contiguous()
        contact_logits = self.output_mix(sim).squeeze(-1)

        if mask1 is not None and mask2 is not None:
            valid_mask = (mask1.unsqueeze(2) * mask2.unsqueeze(1)).bool()
        else:
            valid_mask = torch.ones_like(contact_logits, dtype=torch.bool)

        return contact_logits, valid_mask
|
|
|
|
class FlashPPIPreTrainedModel(PreTrainedModel):
    """Base class for FlashPPI models.

    Provides the Hugging Face config/prefix plumbing and the weight
    initialization hook shared by all FlashPPI model classes.
    """

    config_class = FlashPPIConfig
    base_model_prefix = "flashppi"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        # Uniform scheme across the model: normal(std=0.02) weights, zero biases.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)
        elif isinstance(module, RotaryEmbedding):
            # inv_freq is registered with persistent=False, so it is not stored
            # in checkpoints; recompute the RoPE inverse frequencies here so the
            # buffer holds correct values after (re-)initialization.
            inv_freq = 1.0 / (
                module.base
                ** (
                    torch.arange(0, module.dim, 2, device=module.inv_freq.device, dtype=torch.float32)
                    / module.dim
                )
            )
            # copy_ keeps the registered buffer object; no_grad avoids autograd tracking.
            with torch.no_grad():
                module.inv_freq.copy_(inv_freq)
|
|
class FlashPPIModel(FlashPPIPreTrainedModel):
    """FlashPPI model.

    Couples a shared gLM2 backbone with two task heads: a CLIP-style
    contrastive head (protein-level similarity) and a cross-attention
    contact head (residue-level contact map).
    """

    def __init__(self, config: FlashPPIConfig):
        super().__init__(config)
        self.config = config

        # Shared protein language model backbone; both proteins of a pair are
        # encoded with the same weights.
        self.plm = GLM2Backbone(config)

        # Asymmetric CLIP projection heads: protein 1 -> head_q, protein 2 -> head_k.
        self.head_q = ContrastiveHead(config.plm_dim, config.clip_embed_dim)
        self.head_k = ContrastiveHead(config.plm_dim, config.clip_embed_dim)
        # Learnable temperature; 2.6593 = ln(1/0.07), the CLIP initialization.
        # (Not consumed in forward() here — presumably used by the training loss.)
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.6593)

        # Residue-pair contact prediction head.
        self.contact_head = ContactHead(
            config.plm_dim,
            config.contact_embed_dim,
            num_heads=config.contact_num_heads,
            depth=config.contact_transformer_depth,
        )

        self.post_init()

    def encode_protein(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a protein sequence to residue-level embeddings.

        Args:
            input_ids: (B, L) token IDs from gLM2 tokenizer.
            attention_mask: (B, L) attention mask.

        Returns:
            (B, L, plm_dim) residue embeddings.
        """
        return self.plm(input_ids, attention_mask)

    def predict_contacts(
        self,
        embed1: torch.Tensor,
        embed2: torch.Tensor,
        mask1: torch.Tensor,
        mask2: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict contact map from pre-computed residue embeddings.

        This method is useful for efficient 2-stage inference where embeddings
        are pre-computed and cached.

        Args:
            embed1: (B, L1, D) residue embeddings for protein 1.
            embed2: (B, L2, D) residue embeddings for protein 2.
            mask1: (B, L1) attention mask for protein 1.
            mask2: (B, L2) attention mask for protein 2.

        Returns:
            contact_logits: (B, L1, L2) raw logits.
            valid_mask: (B, L1, L2) mask for valid positions.
        """
        return self.contact_head(embed1, embed2, mask1, mask2)

    def forward(
        self,
        input_ids1: torch.Tensor,
        input_ids2: torch.Tensor,
        attention_mask1: Optional[torch.Tensor] = None,
        attention_mask2: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[Tuple, FlashPPIOutput]:
        """Forward pass for protein pair interaction prediction.

        Args:
            input_ids1: (B, L1) token IDs for protein 1.
            input_ids2: (B, L2) token IDs for protein 2.
            attention_mask1: (B, L1) attention mask for protein 1.
            attention_mask2: (B, L2) attention mask for protein 2.
            return_dict: Whether to return a FlashPPIOutput or tuple.

        Returns:
            FlashPPIOutput with contact predictions and CLIP embeddings.
        """
        B = input_ids1.shape[0]
        L1, L2 = input_ids1.shape[1], input_ids2.shape[1]

        if attention_mask1 is None:
            attention_mask1 = torch.ones_like(input_ids1)
        if attention_mask2 is None:
            attention_mask2 = torch.ones_like(input_ids2)

        # Pad both proteins to a common length so they can be stacked along the
        # batch dimension and encoded in one backbone pass (2B sequences).
        # NOTE(review): pad id 0 is assumed inert because padded positions are
        # masked out — confirm 0 is the tokenizer's pad token id.
        if L1 != L2:
            max_len = max(L1, L2)
            if L1 < max_len:
                pad_len = max_len - L1
                input_ids1 = F.pad(input_ids1, (0, pad_len), value=0)
                attention_mask1 = F.pad(attention_mask1, (0, pad_len), value=0)
            if L2 < max_len:
                pad_len = max_len - L2
                input_ids2 = F.pad(input_ids2, (0, pad_len), value=0)
                attention_mask2 = F.pad(attention_mask2, (0, pad_len), value=0)

        # Single encoder pass over the concatenated pair batch.
        batched_input_ids = torch.cat([input_ids1, input_ids2], dim=0)
        batched_attention_mask = torch.cat([attention_mask1, attention_mask2], dim=0)
        batched_embeds = self.encode_protein(batched_input_ids, batched_attention_mask)

        # Split the 2B batch back into the two proteins and strip the
        # common-length padding added above; masks are sliced to match.
        residue_embeds1 = batched_embeds[:B, :L1, :]
        residue_embeds2 = batched_embeds[B:, :L2, :]
        attention_mask1 = attention_mask1[:, :L1]
        attention_mask2 = attention_mask2[:, :L2]

        # Protein-level CLIP embeddings are unit-normalized, so their dot
        # product is a cosine similarity.
        clip_embed1 = self.head_q(residue_embeds1, attention_mask1)
        clip_embed2 = self.head_k(residue_embeds2, attention_mask2)
        clip_score = (clip_embed1 * clip_embed2).sum(dim=-1)

        # Residue-pair contact probabilities.
        contact_logits, valid_mask = self.contact_head(
            residue_embeds1, residue_embeds2, attention_mask1, attention_mask2
        )
        contact_map = torch.sigmoid(contact_logits)

        # Max contact probability over valid residue pairs = per-pair score.
        contact_map_masked = contact_map.masked_fill(~valid_mask, 0.0)
        contact_score = contact_map_masked.flatten(1).max(dim=-1).values

        if not return_dict:
            return (contact_map, contact_score, clip_embed1, clip_embed2, clip_score)

        return FlashPPIOutput(
            contact_map=contact_map,
            contact_score=contact_score,
            clip_embed1=clip_embed1,
            clip_embed2=clip_embed2,
            clip_score=clip_score,
        )