# flashppi / modeling_flashppi.py
# Last commit d1335fb (verified) by andrecornman: "Use torch swiglu"
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from einops import rearrange, repeat
from torch.utils.checkpoint import checkpoint
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_flashppi import FlashPPIConfig
# Detect Flash Attention installation. flash-attn is optional: if any of its
# symbols fails to import, all of them are set to None and the pure-PyTorch
# rotary / SDPA attention code paths are used instead.
try:
    from flash_attn.layers.rotary import apply_rotary_emb_func
    from flash_attn import flash_attn_varlen_kvpacked_func
    from flash_attn.bert_padding import pad_input, unpad_input
except ImportError:
    # Reset every symbol, including any imported before the failing line, so
    # callers only ever need to consult FLASH_ATTN_AVAILABLE.
    FLASH_ATTN_AVAILABLE = False
    apply_rotary_emb_func = None
    flash_attn_varlen_kvpacked_func = None
    pad_input = None
    unpad_input = None
else:
    FLASH_ATTN_AVAILABLE = True
def swiglu(x, y):
    """SwiGLU gate: ``silu(x) * y``, elementwise with normal broadcasting."""
    gate = F.silu(x)
    return gate * y
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel scale.

    ``eps`` is kept as a plain attribute rather than a registered buffer, so
    the module's state dict contains only ``weight`` (checkpoint compatibility).
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, hidden_states):
        """Normalize over the last dim in float32, scale, and cast back to the input dtype."""
        orig_dtype = hidden_states.dtype
        x = hidden_states.float()
        inv_rms = torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + self.eps)
        scaled = self.weight * (x * inv_rms)
        return scaled.to(orig_dtype)
@dataclass
class FlashPPIOutput(ModelOutput):
    """Outputs of ``FlashPPIModel.forward``.

    Attributes:
        contact_map: (B, L1, L2) sigmoid contact probabilities between residues
            of protein 1 (rows) and protein 2 (columns). Entries at padded
            positions are NOT masked out.
        contact_score: (B,) maximum contact probability per pair, taken over
            valid (non-padded) residue pairs only; 0 if there are none.
        clip_embed1: (B, D) L2-normalized CLIP embedding for the first protein.
        clip_embed2: (B, D) L2-normalized CLIP embedding for the second protein.
        clip_score: (B,) dot product of the two CLIP embeddings, i.e. their
            cosine similarity.
    """
    # Field order matches the tuple returned by FlashPPIModel.forward(return_dict=False).
    contact_map: Optional[torch.FloatTensor] = None
    contact_score: Optional[torch.FloatTensor] = None
    clip_embed1: Optional[torch.FloatTensor] = None
    clip_embed2: Optional[torch.FloatTensor] = None
    clip_score: Optional[torch.FloatTensor] = None
def rotate_half(x, interleaved=False):
    """Map each rotary feature pair (a, b) to (-b, a).

    Non-interleaved layout pairs feature ``i`` with feature ``i + d/2`` (the two
    halves of the last dim); interleaved layout pairs adjacent features
    ``(2i, 2i + 1)``.
    """
    if interleaved:
        evens, odds = x[..., 0::2], x[..., 1::2]
        # Stack as (..., d/2, 2) and flatten so pairs stay adjacent: -o0 e0 -o1 e1 ...
        return torch.stack((-odds, evens), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False, position_ids=None):
    """Apply rotary embeddings to ``x`` using only PyTorch ops.

    ``cos``/``sin`` are per-position tables whose last dim is half the rotated
    width. Rows are gathered with ``position_ids`` when given; otherwise the
    first ``x.shape[1]`` rows are used. Features of ``x`` beyond the rotated
    width are passed through unchanged.
    """
    if position_ids is None:
        n_pos = x.shape[1]
        cos, sin = cos[:n_pos], sin[:n_pos]
    else:
        cos, sin = cos[position_ids], sin[position_ids]

    def _widen(table):
        # Duplicate the half-width table to the full rotated width, matching
        # rotate_half's pairing, and insert a broadcast axis before the feature
        # dim (for a (batch, seq, heads, dim) input this lines up with heads).
        if interleaved:
            full = table.repeat_interleave(2, dim=-1)  # c0 c0 c1 c1 ...
        else:
            full = torch.cat((table, table), dim=-1)  # c0 c1 ... c0 c1 ...
        return full.unsqueeze(-2)

    cos, sin = _widen(cos), _widen(sin)
    rot_width = cos.shape[-1]
    head, tail = x[..., :rot_width], x[..., rot_width:]
    rotated = head * cos + rotate_half(head, interleaved) * sin
    return torch.cat([rotated, tail], dim=-1)
class RotaryEmbedding(nn.Module):
    """Rotary position embeddings (RoPE) with optional flash-attn kernels.

    cos/sin tables of shape ``(seqlen, dim // 2)`` (for even ``dim``) are built
    lazily and reused until a longer sequence, another device, or another
    dtype is requested.
    """

    def __init__(self, dim: int, base: float = 10000.0, interleaved: bool = False, device=None):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.interleaved = interleaved
        # Inverse frequencies base^(-2i / dim); non-persistent, so they are not
        # stored in checkpoints.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rebuild the cos/sin tables when they cannot serve this request.

        A rebuild happens when nothing is cached yet, ``seqlen`` exceeds the
        cached length, or the cached tables live on a different device or have
        a different dtype. The tables are plain attributes (not buffers), so
        ``module.to(...)`` does not convert them. Without the dtype check, a
        table built for one precision would be reused after a cast or an
        autocast toggle: e.g. float32 tables promote bf16 q/k to float32 in the
        PyTorch path (no longer matching v), and bf16 tables lose precision for
        float32 inputs.
        """
        stale = (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or (dtype is not None and self._cos_cached.dtype != dtype)
        )
        if not stale:
            return
        self._seq_len_cached = seqlen
        t = torch.arange(seqlen, device=device, dtype=torch.float32)
        # Angles are computed in float32; only the final tables are cast.
        freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32))
        self._cos_cached = torch.cos(freqs).to(dtype)
        self._sin_cached = torch.sin(freqs).to(dtype)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate ``q`` and ``k`` by their positions.

        With flash-attn available and ``cu_seqlens`` given, q/k are unpadded
        ``(total_tokens, heads, head_dim)`` tensors (as produced by Attention)
        rotated in place by the flash-attn kernel. Otherwise the pure-PyTorch
        path is used, with positions from ``position_ids`` when provided.
        """
        seqlen = q.shape[1] if max_seqlen is None else max_seqlen
        self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)
        if FLASH_ATTN_AVAILABLE and cu_seqlens is not None:
            q = apply_rotary_emb_func(
                q, self._cos_cached, self._sin_cached,
                interleaved=self.interleaved, inplace=True,
                cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
            )
            k = apply_rotary_emb_func(
                k, self._cos_cached, self._sin_cached,
                interleaved=self.interleaved, inplace=True,
                cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
            )
        else:
            q = apply_rotary_emb_torch(q, self._cos_cached, self._sin_cached, self.interleaved, position_ids)
            k = apply_rotary_emb_torch(k, self._cos_cached, self._sin_cached, self.interleaved, position_ids)
        return q, k
class Attention(nn.Module):
    """Bidirectional multi-head self-attention with a flash-attn or SDPA backend.

    The flash-attn varlen kernel is used when ``cu_seqlens`` is given and
    flash-attn is importable; inputs are then unpadded ``(total_tokens, dim)``.
    Otherwise inputs are padded ``(batch, seq, dim)`` and PyTorch's
    ``scaled_dot_product_attention`` is used with an optional key mask.
    Neither path applies dropout or a causal mask.
    """

    def __init__(self, dim: int, num_heads: int, use_rope: bool = True):
        super().__init__()
        self.n_heads = num_heads
        self.head_dim = dim // num_heads
        inner_dim = num_heads * self.head_dim
        # Fused projection laid out as [q | k | v] along the last dim.
        self.wqkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.wo = nn.Linear(inner_dim, dim, bias=False)
        self.rotary_emb = RotaryEmbedding(self.head_dim) if use_rope else None

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seq_len: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``x`` and return a tensor with the same leading shape and ``dim`` features."""
        inner_dim = self.n_heads * self.head_dim
        fused = self.wqkv(x)
        # (..., inner_dim) -> (..., heads, head_dim) for each of q, k, v.
        q, k, v = (
            part.unflatten(-1, (self.n_heads, self.head_dim))
            for part in fused.split(inner_dim, dim=-1)
        )

        if cu_seqlens is not None and FLASH_ATTN_AVAILABLE:
            # Unpadded tokens: q/k/v are (total_tokens, heads, head_dim).
            if self.rotary_emb is not None:
                q, k = self.rotary_emb(q, k, cu_seqlens=cu_seqlens, max_seqlen=max_seq_len)
            out = flash_attn_varlen_kvpacked_func(
                q,
                torch.stack([k, v], 1),  # (total_tokens, 2, heads, head_dim)
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seq_len,
                max_seqlen_k=max_seq_len,
                dropout_p=0.0,
                causal=False,
            )
            return self.wo(out.flatten(-2))

        # Padded batch: q/k/v are (batch, seq, heads, head_dim).
        bsz, seqlen, _ = x.shape
        if self.rotary_emb is not None:
            q, k = self.rotary_emb(q, k, position_ids=position_ids)
        key_mask = None
        if attention_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): True keys are attendable by every query.
            key_mask = attention_mask[:, None, None, :].bool()
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            attn_mask=key_mask,
            dropout_p=0.0,
            is_causal=False,
        )
        return self.wo(out.transpose(1, 2).reshape(bsz, seqlen, inner_dim))
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: ``w2(silu(w1 x) * w3 x)``.

    The hidden width starts at ``2/3 * hidden_mult * dim``, is optionally scaled
    by ``ffn_dim_multiplier``, and is then rounded up to a multiple of
    ``multiple_of``.
    """

    def __init__(self, dim: int, hidden_mult: float = 4.0, multiple_of: int = 256, ffn_dim_multiplier: float = None):
        super().__init__()
        width = int(2 * dim * hidden_mult / 3)
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        width = -(-width // multiple_of) * multiple_of  # ceil to a multiple of multiple_of
        self.w1 = nn.Linear(dim, width, bias=False)
        self.w2 = nn.Linear(width, dim, bias=False)
        self.w3 = nn.Linear(dim, width, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.w1(x)
        up = self.w3(x)
        return self.w2(swiglu(gate, up))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: ``h = x + attn(norm(x))``, then ``h + ffn(norm(h))``."""

    def __init__(self, dim: int, num_heads: int, norm_eps: float = 1e-6,
                 multiple_of: int = 256, ffn_dim_multiplier: float = None, use_rope: bool = True):
        super().__init__()
        # Submodule creation order is kept: it fixes parameter order and init RNG order.
        self.attention = Attention(dim, num_heads, use_rope)
        self.feed_forward = FeedForward(dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier)
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seq_len: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sublayers, each with a residual connection."""
        attn_out = self.attention(
            self.attention_norm(x),
            cu_seqlens=cu_seqlens,
            max_seq_len=max_seq_len,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        h = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + ffn_out
class TransformerLayers(nn.Module):
    """Stack of pre-norm transformer blocks with optional flash-attn unpadding.

    When flash-attn can run on the activations (see ``_flash_attn_usable``) and
    the mask contains padding, padded tokens are removed before the blocks and
    scattered back afterwards. Otherwise the padded batch is processed with SDPA,
    a key mask, and rotary positions derived from the mask.
    """

    def __init__(self, dim: int, num_heads: int, depth: int, norm_eps: float = 1e-6,
                 multiple_of: int = 256, ffn_dim_multiplier: float = None, use_rope: bool = True):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([
            TransformerBlock(dim, num_heads, norm_eps, multiple_of, ffn_dim_multiplier, use_rope)
            for _ in range(depth)
        ])
        # Checked in forward(); presumably toggled by Hugging Face's
        # gradient_checkpointing_enable() — confirm against the installed transformers.
        self.gradient_checkpointing = False

    @staticmethod
    def _flash_attn_usable(x: torch.Tensor) -> bool:
        """Return True if the flash-attn varlen kernels can process activations like ``x``.

        flash-attn kernels require CUDA tensors in fp16/bf16. Under CUDA autocast
        the q/k/v projections come out in reduced precision even when the
        residual stream ``x`` is float32, so that case is allowed as well.
        """
        if not (FLASH_ATTN_AVAILABLE and x.is_cuda):
            return False
        return x.dtype in (torch.float16, torch.bfloat16) or torch.is_autocast_enabled()

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run every block over ``x``.

        Args:
            x: (B, S, dim) hidden states.
            attention_mask: optional (B, S) mask, nonzero for real tokens.

        Returns:
            (B, S, dim) hidden states in the same padded layout as the input.
        """
        batch_size, seq_len = x.shape[:2]
        cu_seqlens = max_seq_len_in_batch = indices = position_ids = None
        mask_for_layers = None
        if attention_mask is not None and self._flash_attn_usable(x) and not attention_mask.all():
            # Drop padded tokens: x becomes (total_tokens, dim). Some flash-attn
            # releases return an extra trailing value from unpad_input, others
            # return only four, so anything past the fourth is ignored.
            x, indices, cu_seqlens, max_seq_len_in_batch, *_ = unpad_input(x, attention_mask)
        elif attention_mask is not None:
            # Each real token's rotary position is its index among the real tokens
            # of its row, wherever the padding sits (left, right, or in between).
            position_ids = (attention_mask.long().cumsum(dim=1) - 1).clamp(min=0)
            mask_for_layers = attention_mask
        for layer in self.layers:
            if self.training and self.gradient_checkpointing:
                x = checkpoint(layer, x, cu_seqlens, max_seq_len_in_batch, mask_for_layers, position_ids, use_reentrant=False)
            else:
                x = layer(x, cu_seqlens, max_seq_len_in_batch, mask_for_layers, position_ids)
        if indices is not None:
            # Scatter tokens back into the padded (B, S, dim) layout.
            x = pad_input(x, indices, batch_size, seq_len)
        return x
class GLM2Backbone(nn.Module):
    """gLM2 protein language model backbone: token embedding followed by a RoPE transformer stack."""

    def __init__(self, config: FlashPPIConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.plm_vocab_size, config.plm_dim)
        self.encoder = TransformerLayers(
            config.plm_dim,
            config.plm_heads,
            config.plm_depth,
            norm_eps=config.plm_norm_eps,
            multiple_of=config.plm_swiglu_multiple_of,
            ffn_dim_multiplier=config.plm_ffn_dim_multiplier,
            use_rope=True,
        )

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed ``input_ids`` (B, L) and return (B, L, plm_dim) hidden states."""
        embeddings = self.tok_embeddings(input_ids)
        return self.encoder(embeddings, attention_mask=attention_mask)
class MLPHead(nn.Module):
    """SwiGLU projection head: ``w2(silu(w1 x) * w3 x)`` with hidden width ``int(in_dim * hidden_mult)``."""

    def __init__(self, in_dim: int, out_dim: int, hidden_mult: float = 2.0):
        super().__init__()
        inner = int(in_dim * hidden_mult)
        self.w1 = nn.Linear(in_dim, inner, bias=False)   # gate branch
        self.w2 = nn.Linear(inner, out_dim, bias=False)  # output projection
        self.w3 = nn.Linear(in_dim, inner, bias=False)   # value branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w1(x))
        return self.w2(gated * self.w3(x))
class ContrastiveHead(nn.Module):
    """CLIP-style contrastive head: masked mean pooling, MLP projection, L2 normalization."""

    def __init__(self, hidden_dim: int, embed_dim: int):
        super().__init__()
        self.head = MLPHead(hidden_dim, embed_dim)

    def forward(self, residue_embeds: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Pool residue embeddings into one unit-norm vector per sequence.

        Args:
            residue_embeds: (B, L, hidden_dim) per-residue embeddings.
            mask: (B, L) mask, nonzero for residues included in the mean.

        Returns:
            (B, embed_dim) L2-normalized embeddings. A row with an all-zero
            mask pools to zeros (the count is clamped to 1).
        """
        keep = mask.unsqueeze(-1).bool()
        summed = torch.where(keep, residue_embeds, 0.0).sum(dim=1)
        counts = keep.sum(dim=1).float().clamp(min=1.0)
        # The float32 count promotes a half/bf16 sum to float32; cast back so the
        # pooled vector matches the model's (and the MLP weights') dtype instead
        # of failing with a dtype mismatch in the Linear layers. No-op for fp32/fp64.
        pooled = (summed / counts).to(residue_embeds.dtype)
        return F.normalize(self.head(pooled), dim=-1)
class ContactHead(nn.Module):
    """Predict residue-residue contact logits between two proteins.

    Residues of both proteins are tagged with a learned segment embedding,
    concatenated, and mixed by a small RoPE transformer. Per-head scaled
    dot-product scores between the two halves are then combined into one logit
    per residue pair by ``output_mix``.
    """

    def __init__(self, input_dim: int, contact_dim: int, num_heads: int = 8, depth: int = 2):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = contact_dim // num_heads
        assert contact_dim % num_heads == 0
        # Segment id 0 tags protein 1, id 1 tags protein 2.
        self.segment_embed = nn.Embedding(2, input_dim)
        nn.init.normal_(self.segment_embed.weight, std=0.02)
        self.transformer = TransformerLayers(input_dim, num_heads, depth, use_rope=True)
        self.norm = nn.LayerNorm(input_dim)
        self.q_proj = nn.Linear(input_dim, contact_dim, bias=True)
        self.k_proj = nn.Linear(input_dim, contact_dim, bias=True)
        self.output_mix = nn.Linear(num_heads, 1)
        # Negative bias -> low initial contact probability (sigmoid(-3) ~= 0.05).
        nn.init.constant_(self.output_mix.bias, -3.0)
        self.scale = self.head_dim ** -0.5

    def forward(
        self,
        embed1: torch.Tensor,
        embed2: torch.Tensor,
        mask1: torch.Tensor,
        mask2: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score every residue pair (i in protein 1, j in protein 2).

        Args:
            embed1: (B, L1, D) residue embeddings of protein 1.
            embed2: (B, L2, D) residue embeddings of protein 2.
            mask1: (B, L1) mask for protein 1, or None.
            mask2: (B, L2) mask for protein 2, or None.

        Returns:
            contact_logits: (B, L1, L2) raw logits (padded pairs are not masked).
            valid_mask: (B, L1, L2) bool, True where both residues are real; all
                True when either mask is None.
        """
        B, L1, _ = embed1.shape
        _, L2, _ = embed2.shape

        protein1 = embed1 + self.segment_embed(embed1.new_zeros(L1, dtype=torch.long))
        protein2 = embed2 + self.segment_embed(embed1.new_ones(L2, dtype=torch.long))
        joint = torch.cat([protein1, protein2], dim=1)

        have_masks = mask1 is not None and mask2 is not None
        joint_mask = torch.cat([mask1, mask2], dim=1).bool() if have_masks else None
        joint = self.transformer(joint, attention_mask=joint_mask)

        h1 = self.norm(joint[:, :L1, :])
        h2 = self.norm(joint[:, L1:, :])
        q = self.q_proj(h1).view(B, L1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(h2).view(B, L2, self.num_heads, self.head_dim).transpose(1, 2)

        # (B, heads, L1, L2) scores -> (B, L1, L2, heads) so output_mix mixes over heads.
        scores = (q @ k.transpose(-2, -1)) * self.scale
        contact_logits = self.output_mix(scores.permute(0, 2, 3, 1).contiguous()).squeeze(-1)

        if not have_masks:
            return contact_logits, torch.ones_like(contact_logits, dtype=torch.bool)
        valid_mask = (mask1.unsqueeze(2) * mask2.unsqueeze(1)).bool()
        return contact_logits, valid_mask
class FlashPPIPreTrainedModel(PreTrainedModel):
    """Base class hooking FlashPPI models into the Hugging Face ``PreTrainedModel`` API."""

    config_class = FlashPPIConfig
    base_model_prefix = "flashppi"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Per-module weight initializer used by Hugging Face ``post_init()``.

        - ``nn.Linear``: weight ~ N(0, 0.02), bias zeroed.
        - ``nn.Embedding``: weight ~ N(0, 0.02).
        - ``RotaryEmbedding``: inverse-frequency buffer recomputed.
        - ``ContactHead``: ``output_mix.bias`` restored to its -3.0 prior.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)
        elif isinstance(module, RotaryEmbedding):
            # Re-calculate the frequencies using the module's stored attributes.
            inv_freq = 1.0 / (
                module.base
                ** (
                    torch.arange(0, module.dim, 2, device=module.inv_freq.device, dtype=torch.float32)
                    / module.dim
                )
            )
            # Force the buffer to update.
            with torch.no_grad():
                module.inv_freq.copy_(inv_freq)
        elif isinstance(module, ContactHead):
            # ContactHead.__init__ sets output_mix.bias to -3.0 (initial contact
            # probability sigmoid(-3) ~= 0.05), but the nn.Linear branch above zeroes
            # every Linear bias, silently discarding that prior. Re-apply it here.
            # NOTE(review): relies on children being initialized before their parent
            # (as in nn.Module.apply) — confirm for the installed transformers version.
            nn.init.constant_(module.output_mix.bias, -3.0)
class FlashPPIModel(FlashPPIPreTrainedModel):
    """FlashPPI protein-protein interaction model.

    A shared gLM2 backbone encodes both proteins. Two contrastive heads produce
    CLIP-style embeddings (``head_q`` for protein 1, ``head_k`` for protein 2),
    and a contact head predicts residue-residue contacts between the pair.
    """

    def __init__(self, config: FlashPPIConfig):
        super().__init__(config)
        self.config = config
        # Creation order is kept stable: it fixes parameter order and the order in
        # which random initialization consumes the RNG.
        self.plm = GLM2Backbone(config)
        self.head_q = ContrastiveHead(config.plm_dim, config.clip_embed_dim)
        self.head_k = ContrastiveHead(config.plm_dim, config.clip_embed_dim)
        # Initialized to ln(1/0.07). Not read by forward() in this file; presumably
        # consumed by a contrastive training loss elsewhere — confirm.
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.6593)
        self.contact_head = ContactHead(
            config.plm_dim,
            config.contact_embed_dim,
            num_heads=config.contact_num_heads,
            depth=config.contact_transformer_depth,
        )
        self.post_init()

    def encode_protein(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode protein token ids into per-residue embeddings.

        Args:
            input_ids: (B, L) token IDs from the gLM2 tokenizer.
            attention_mask: optional (B, L) mask, nonzero for real tokens.

        Returns:
            (B, L, plm_dim) residue embeddings.
        """
        return self.plm(input_ids, attention_mask)

    def predict_contacts(
        self,
        embed1: torch.Tensor,
        embed2: torch.Tensor,
        mask1: torch.Tensor,
        mask2: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict a contact map from pre-computed residue embeddings.

        Useful for two-stage inference where embeddings from ``encode_protein``
        are cached and reused across many pairs.

        Args:
            embed1: (B, L1, D) residue embeddings for protein 1.
            embed2: (B, L2, D) residue embeddings for protein 2.
            mask1: (B, L1) attention mask for protein 1.
            mask2: (B, L2) attention mask for protein 2.

        Returns:
            contact_logits: (B, L1, L2) raw logits.
            valid_mask: (B, L1, L2) mask of valid residue pairs.
        """
        return self.contact_head(embed1, embed2, mask1, mask2)

    def _encode_pair(self, input_ids1, input_ids2, attention_mask1, attention_mask2):
        """Encode both proteins with a single batched backbone call.

        The shorter input is right-padded with token 0 / mask 0 so the two
        batches can be concatenated. Embeddings and masks are trimmed back to
        each protein's own length before being returned.
        """
        n = input_ids1.shape[0]
        len1, len2 = input_ids1.shape[1], input_ids2.shape[1]
        width = max(len1, len2)
        if len1 < width:
            input_ids1 = F.pad(input_ids1, (0, width - len1), value=0)
            attention_mask1 = F.pad(attention_mask1, (0, width - len1), value=0)
        if len2 < width:
            input_ids2 = F.pad(input_ids2, (0, width - len2), value=0)
            attention_mask2 = F.pad(attention_mask2, (0, width - len2), value=0)
        hidden = self.encode_protein(
            torch.cat([input_ids1, input_ids2], dim=0),
            torch.cat([attention_mask1, attention_mask2], dim=0),
        )
        return (
            hidden[:n, :len1, :],
            hidden[n:, :len2, :],
            attention_mask1[:, :len1],
            attention_mask2[:, :len2],
        )

    def forward(
        self,
        input_ids1: torch.Tensor,
        input_ids2: torch.Tensor,
        attention_mask1: Optional[torch.Tensor] = None,
        attention_mask2: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[Tuple, FlashPPIOutput]:
        """Predict the interaction between two batches of proteins.

        Args:
            input_ids1: (B, L1) token IDs for protein 1.
            input_ids2: (B, L2) token IDs for protein 2.
            attention_mask1: (B, L1) attention mask for protein 1 (all ones if None).
            attention_mask2: (B, L2) attention mask for protein 2 (all ones if None).
            return_dict: return a FlashPPIOutput if True, else a plain tuple
                ``(contact_map, contact_score, clip_embed1, clip_embed2, clip_score)``.

        Returns:
            FlashPPIOutput (or tuple) with contact predictions and CLIP embeddings.
        """
        if attention_mask1 is None:
            attention_mask1 = torch.ones_like(input_ids1)
        if attention_mask2 is None:
            attention_mask2 = torch.ones_like(input_ids2)

        residue_embeds1, residue_embeds2, attention_mask1, attention_mask2 = self._encode_pair(
            input_ids1, input_ids2, attention_mask1, attention_mask2
        )

        # Contrastive embeddings are unit-norm, so their dot product is a cosine similarity.
        clip_embed1 = self.head_q(residue_embeds1, attention_mask1)
        clip_embed2 = self.head_k(residue_embeds2, attention_mask2)
        clip_score = (clip_embed1 * clip_embed2).sum(dim=-1)

        contact_logits, valid_mask = self.contact_head(
            residue_embeds1, residue_embeds2, attention_mask1, attention_mask2
        )
        contact_map = torch.sigmoid(contact_logits)
        # Padded pairs are zeroed so they cannot win the per-pair max.
        contact_score = contact_map.masked_fill(~valid_mask, 0.0).flatten(1).max(dim=-1).values

        if not return_dict:
            return (contact_map, contact_score, clip_embed1, clip_embed2, clip_score)
        return FlashPPIOutput(
            contact_map=contact_map,
            contact_score=contact_score,
            clip_embed1=clip_embed1,
            clip_embed2=clip_embed2,
            clip_score=clip_score,
        )