# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
from huggingface_hub import PyTorchModelHubMixin


class LoGo_BERT(nn.Module, PyTorchModelHubMixin):
    """LoGoBERT-PPI: protein-protein interaction classifier.

    Combines global SBERT-style features (mean-pooled embeddings of both
    sequences and their absolute difference) with an optional local,
    ColBERT-style MaxSim score, and classifies the pair with a small MLP.
    """

def __init__(
self,
model_name: str = "facebook/esm2_t33_650M_UR50D",
embedding_dim: int = 512,
dropout: float = 0.1,
pos_weight: float = 1.0,
use_ln_g1: bool = True,
score_norm: str = "none",
hidden_mult: int = 1,
act: str = "relu",
score_fn: str = "dot",
use_maxsim: bool = True,
):
super().__init__()
self._init_args = dict(
model_name=model_name,
embedding_dim=embedding_dim,
dropout=dropout,
pos_weight=float(pos_weight),
use_ln_g1=bool(use_ln_g1),
score_norm=str(score_norm),
hidden_mult=int(hidden_mult),
act=str(act),
score_fn=str(score_fn),
use_maxsim=bool(use_maxsim),
)
        # Pretrained protein language model backbone (ESM-2 by default).
        self.encoder = AutoModel.from_pretrained(model_name)
        self.projection = nn.Linear(self.encoder.config.hidden_size, embedding_dim)
        self.dropout = nn.Dropout(dropout)
        # SiLU if requested; any other value falls back to ReLU.
        self.act = nn.SiLU() if act.lower() == "silu" else nn.ReLU()
        # Classifier input: [a, b, |a - b|] pooled features (3D) plus the MaxSim score (1).
        input_dim = 3 * embedding_dim + 1
        h = embedding_dim * hidden_mult
self.classifier = nn.Sequential(
nn.Linear(input_dim, h),
self.act,
nn.Dropout(dropout),
nn.Linear(h, 1),
)
        # Learnable per-feature scaling for the global features and a scalar
        # weight for the local MaxSim score.
        self.sbert_weight = nn.Parameter(torch.ones(3 * embedding_dim))
        self.maxsim_weight = nn.Parameter(torch.ones(1))
        self.use_ln_g1 = use_ln_g1
        self.use_maxsim = use_maxsim
        if self.use_ln_g1:
            self.ln_g1 = nn.LayerNorm(3 * embedding_dim)
        self.score_norm = score_norm.lower()
        self.score_fn = score_fn
        # Registered as a buffer so the positive-class weight moves with the module.
        self.register_buffer("pos_weight_buf", torch.tensor([float(pos_weight)]))

    @property
    def config(self):
        return self._init_args

    def encode(self, input_ids, attention_mask):
        """Return projected, dropout-regularized token embeddings of shape (B, L, embedding_dim)."""
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        token_embed = self.dropout(self.projection(outputs.last_hidden_state))
        return token_embed

    def mean_pooling(self, embed, mask):
        """Mask-aware mean over the sequence dimension, ignoring padding tokens."""
        mask = mask.unsqueeze(-1).float()
        summed = torch.sum(embed * mask, dim=1)
        counted = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counted

    def maxsim_dot(self, emb_a, mask_a, emb_b, mask_b, score_fn="dot"):
        """ColBERT-style MaxSim: each valid token in A takes its best similarity
        over all valid tokens in B; the maxima are then averaged over A.
        Returns a (B, 1) score per pair."""
        if score_fn == "dot":
            sim_matrix = torch.bmm(emb_a, emb_b.transpose(1, 2))
        elif score_fn == "cosine":
            a = F.normalize(emb_a, dim=-1)
            b = F.normalize(emb_b, dim=-1)
            sim_matrix = torch.bmm(a, b.transpose(1, 2))
        else:
            raise ValueError(f"Invalid mode: {score_fn}. Choose 'dot' or 'cosine'.")
        # Exclude padding positions on both sides before taking the max.
        neg_inf = torch.finfo(sim_matrix.dtype).min
        sim_matrix = sim_matrix.masked_fill(~mask_a[:, :, None].bool(), neg_inf)
        sim_matrix = sim_matrix.masked_fill(~mask_b[:, None, :].bool(), neg_inf)
        max_per_query = sim_matrix.max(dim=2).values
        # Guard against pairs where B has no valid tokens at all.
        has_valid_key = mask_b.any(dim=1, keepdim=True)
        max_per_query = torch.where(has_valid_key, max_per_query, torch.zeros_like(max_per_query))
        # Average only over valid tokens in A.
        mask_a_float = mask_a.float()
        max_per_query = max_per_query * mask_a_float
        summed = torch.sum(max_per_query, dim=1, keepdim=True)
        valid_len = mask_a_float.sum(dim=1, keepdim=True).clamp(min=1e-9)
        return summed / valid_len

    def forward(self, input_a, input_b, labels=None):
        emb_a = self.encode(input_a["input_ids"], input_a["attention_mask"])
        emb_b = self.encode(input_b["input_ids"], input_b["attention_mask"])
        # Global features: pooled embeddings of each sequence and their absolute difference.
        pooled_a = self.mean_pooling(emb_a, input_a["attention_mask"])
        pooled_b = self.mean_pooling(emb_b, input_b["attention_mask"])
        abs_diff = torch.abs(pooled_a - pooled_b)
        group1 = torch.cat([pooled_a, pooled_b, abs_diff], dim=1)
        if self.use_ln_g1:
            group1 = self.ln_g1(group1)
        weighted_group1 = group1 * self.sbert_weight
        # Local feature: token-level MaxSim score (a zero column when disabled).
        if self.use_maxsim:
            score = self.maxsim_dot(
                emb_a, input_a["attention_mask"],
                emb_b, input_b["attention_mask"],
                score_fn=self.score_fn,
            )
            if self.score_norm == "tanh":
                score = torch.tanh(score)
            weighted_group2 = score * self.maxsim_weight
        else:
            weighted_group2 = pooled_a.new_zeros(pooled_a.size(0), 1)
        concat = torch.cat([weighted_group1, weighted_group2], dim=1)  # (B, 3D+1)
        logits = self.classifier(concat).squeeze(-1)
        if labels is not None:
            # Training path: class-weighted BCE on logits, returning (loss, logits).
            loss_fn = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight_buf.to(logits.device))
            loss = loss_fn(logits, labels.float())
            return loss, logits
        # Inference path: return probabilities.
        return torch.sigmoid(logits)

    @torch.inference_mode()
    def predict_from_embeds(self, emb_a, mask_a, emb_b, mask_b):
        """Score a pair from precomputed token embeddings, skipping the encoder."""
pooled_a = self.mean_pooling(emb_a, mask_a)
pooled_b = self.mean_pooling(emb_b, mask_b)
abs_diff = torch.abs(pooled_a - pooled_b)
group1 = torch.cat([pooled_a, pooled_b, abs_diff], dim=1)
if self.use_ln_g1:
group1 = self.ln_g1(group1)
weighted_group1 = group1 * self.sbert_weight
if self.use_maxsim:
score = self.maxsim_dot(
emb_a, mask_a, emb_b, mask_b,
score_fn=self.score_fn
)
if self.score_norm == "tanh":
score = torch.tanh(score)
else:
score = pooled_a.new_zeros(pooled_a.size(0), 1)
weighted_group2 = score * self.maxsim_weight
concat = torch.cat([weighted_group1, weighted_group2], dim=1) # (B, 3D+1)
logits = self.classifier(concat).squeeze(-1)
probs = torch.sigmoid(logits)
return probs
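

# Minimal usage sketch (not part of the original file). It assumes the paired
# inputs are plain amino-acid strings tokenized with the same checkpoint's
# tokenizer; the small esm2_t6_8M_UR50D backbone and the example sequences
# below are illustrative stand-ins chosen for a quick local test.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "facebook/esm2_t6_8M_UR50D"  # hypothetical smaller backbone for testing
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = LoGo_BERT(model_name=checkpoint)
    model.eval()

    seqs_a = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"]  # illustrative sequence
    seqs_b = ["MSDNGPQNQRNAPRITFGGPSDSTGSNQNGER"]   # illustrative sequence
    input_a = tokenizer(seqs_a, return_tensors="pt", padding=True)
    input_b = tokenizer(seqs_b, return_tensors="pt", padding=True)

    # Without labels, forward() returns interaction probabilities in [0, 1].
    with torch.no_grad():
        probs = model(input_a, input_b)
    print(probs)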