# nas/PFMBench — src/model/finetune_model.py
# (file-hosting page residue kept as comments so the module parses:
#  uploaded by "yuccaaa" via upload-large-folder tool; revision 9627ce0, verified)
import torch
import torch.nn as nn
from src.model.pretrain_model_interface import PretrainModelInterface
class UniModel(nn.Module):
    """Unified finetuning wrapper over a pretrained protein language model.

    Backbone embeddings (optionally prepended with a projected SMILES feature
    token) are projected into a shared hidden space and fed to a
    task-specific head with a matching loss.
    """

    # Embedding width produced by each supported backbone.
    INPUT_DIMS = {
        'esm2_35m': 480,
        'esm2_150m': 640,
        'esm2_650m': 1280,
        'esm2_3b': 2560,
        'esm2_15b': 5120,
        'esm3_1.4b': 1536,
        'esmc_600m': 1152,
        'progen2': 1536,
        'prostt5': 2048,
        'protgpt2': 1280,
        'protrek_35m': 960,  # 480 * 2 — presumably two concatenated streams; confirm
        'protrek': 1920,
        'saport': 1280,
        'saport_35m': 480,
        'saport_1.3b': 1280,
        'procyon': 4096,
        'prollama': 4096,
        'prost': 512,
        'gearnet': 3072,
        'venusplm': 1024,
        'prosst2048': 768,
        'prott5': 1024,
        'dplm': 1280,
        'dplm_150m': 640,
        'dplm_3b': 2560,
        'ontoprotein': 1024,
        'ankh_base': 768,
        'pglm': 2048,
        'pglm-3b': 2560,
    }

    def __init__(
        self,
        pretrain_model_name: str,
        task_type: str,
        finetune_type: str,
        num_classes: int,
        peft_type: str = "lora",
        **kwargs
    ):
        """
        Args:
            pretrain_model_name: backbone key; must be a key of ``INPUT_DIMS``.
            task_type: task family; selects the head and the loss below.
            finetune_type: 'adapter' trains a Transformer adapter on cached
                embeddings; 'peft' wraps the backbone via
                ``PretrainModelInterface``; any other value trains only the
                projection and the task head.
            num_classes: output width for classification-style heads.
            peft_type: PEFT method forwarded to ``setup_peft`` (default 'lora').
            **kwargs: extra options forwarded to ``setup_peft``.

        Raises:
            ValueError: if ``pretrain_model_name`` is unknown (previously this
                surfaced later as an opaque ``AttributeError``).
        """
        super().__init__()
        self.pretrain_model_name = pretrain_model_name
        self.task_type = task_type
        self.finetune_type = finetune_type
        hid_dim = 480  # shared hidden width used by every head
        try:
            self.input_dim = self.INPUT_DIMS[pretrain_model_name]
        except KeyError:
            raise ValueError(
                f"Unknown pretrain_model_name: {pretrain_model_name!r}; "
                f"expected one of {sorted(self.INPUT_DIMS)}"
            ) from None
        # Project 2048-d SMILES features into the shared hidden space.
        self.smiles_proj = nn.Sequential(
            nn.Linear(2048, hid_dim),
            # nn.GELU()
        )
        self.proj = nn.Sequential(
            nn.Linear(self.input_dim, hid_dim),
            # nn.LayerNorm(hid_dim)
        )
        self.layernorm = nn.LayerNorm(hid_dim)
        if finetune_type == 'adapter':
            # Trainable Transformer adapter applied on top of cached embeddings.
            self.adapter = TransformerAdapter(
                input_dim=hid_dim,
                hidden_dim=hid_dim,
                num_layers=6,
                num_heads=20,
            )
        elif finetune_type == 'peft':
            self.pretrain_model_interface = PretrainModelInterface(
                pretrain_model_name,
                task_type=self.task_type
            )
            self.pretrain_model_interface.setup_peft(
                peft_type=peft_type,
                **kwargs
            )
            # Direct handle so the (PEFT-augmented) backbone parameters are
            # registered on this module.
            self.pretrain_model = self.pretrain_model_interface.pretrain_model.model
        # Task head + loss selection.
        if task_type in ['classification', 'residual_classification']:
            self.task_head = nn.Linear(hid_dim, num_classes)
            self.loss = nn.CrossEntropyLoss()
        if task_type in [
            "regression",
            "pair_regression"
        ]:
            # Flatten (batch, 1) -> (batch,) so predictions align with labels.
            self.task_head = nn.Sequential(nn.Linear(hid_dim, 1),
                                           nn.Flatten(start_dim=0, end_dim=1))
            self.loss = nn.MSELoss()
        if task_type == 'contact':
            self.task_head = ContactPredictionHead(hid_dim)
            self.loss = ContatcLoss()
        if task_type in [
            'binary_classification',
            'pair_binary_classification',
            'multi_labels_classification',
        ]:
            self.task_head = nn.Linear(hid_dim, num_classes)
            self.loss = nn.BCEWithLogitsLoss()

    def forward(self, batch):
        """Encode the batch and apply the task head and loss.

        Args:
            batch: for 'adapter', a dict with 'label', 'attention_mask',
                'embedding' and optionally 'smiles'; for 'peft', whatever
                ``PretrainModelInterface`` consumes.

        Returns:
            dict with 'loss', 'logits', 'label' and 'attention_mask'.
        """
        if self.finetune_type == 'adapter':
            labels = batch['label']
            attention_mask = batch['attention_mask']
            # Cached backbone embeddings -> hidden space -> adapter.
            proj_output = self.proj(batch['embedding'])
            proj_output = self.adapter(proj_output, mask=attention_mask)
            smiles = batch.get('smiles')  # tolerate batches without the key
        elif self.finetune_type == "peft":
            embeddings, labels, attention_mask, smiles = self.pretrain_model_interface(batch)
            # Match the projection dtype (backbone may run in half precision).
            proj_output = self.proj(embeddings.to(self.proj[0].weight.dtype))
        if smiles is not None:
            # Prepend one projected SMILES token and extend the mask to match.
            smiles_proj_output = self.smiles_proj(smiles).unsqueeze(1)
            smiles_attention_mask = torch.ones(
                attention_mask.shape[0], 1, device=attention_mask.device
            ).bool()
            proj_output = torch.cat((smiles_proj_output, proj_output), dim=1).contiguous()
            attention_mask = torch.cat((smiles_attention_mask, attention_mask), dim=-1).contiguous()
        proj_output = self.layernorm(proj_output)
        if self.task_type == 'contact':  # residue-pair level
            logits = self.task_head(proj_output)
            loss = self.loss(logits, labels.float(), attention_mask)
        elif self.task_type == 'residual_classification':  # residue level
            logits = self.task_head(proj_output)
            # Keep only real (unpadded) residues for the loss.
            logits = logits[attention_mask]
            labels = labels[attention_mask]
            loss = self.loss(logits, labels.long())
        else:  # sequence level
            # NOTE(review): mean pooling includes padded positions; a masked
            # mean may be intended — confirm before changing.
            pooled_output = torch.mean(proj_output, dim=1)
            logits = self.task_head(pooled_output)
            if isinstance(self.loss, nn.BCEWithLogitsLoss):
                labels = labels.float()
                if labels.ndim == 1:
                    labels = labels.unsqueeze(1)  # BCE expects (batch, num_classes)
            elif isinstance(self.loss, nn.CrossEntropyLoss):
                labels = labels.long()
            else:
                # MSELoss and friends: match the prediction dtype.
                labels = labels.to(logits.dtype)
            loss = self.loss(logits, labels)
        return {'loss': loss, 'logits': logits, 'label': labels, 'attention_mask': attention_mask}
# Transformer adapter module
class TransformerAdapter(nn.Module):
    """Stack of Transformer encoder layers used as a trainable adapter.

    Operates on (batch, seq_len, dim) inputs with an optional boolean
    attention mask.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
        """
        Args:
            input_dim: embedding size (d_model) of the inputs.
            hidden_dim: base width; the FFN uses ``hidden_dim * 4``.
            num_layers: number of stacked encoder layers.
            num_heads: number of attention heads (must divide input_dim).
        """
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,  # FFN inner width
            activation='gelu',
            batch_first=True  # inputs are (batch, seq_len, dim)
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

    def forward(self, x, mask=None):
        """
        Args:
            x: input embeddings, shape (batch_size, seq_len, dim).
            mask: optional boolean attention mask, shape (batch_size, seq_len),
                True = valid token. ``None`` means no padding.

        Returns:
            Encoded tensor of shape (batch_size, seq_len, dim).
        """
        # nn.TransformerEncoder expects True where positions are IGNORED, so
        # invert the valid-token mask. Previously a ``mask=None`` call crashed
        # on ``~None``; now it is passed through as "no padding".
        padding_mask = None if mask is None else ~mask
        return self.transformer_encoder(x, src_key_padding_mask=padding_mask)
class ContactPredictionHead(nn.Module):
    """Predict pairwise contact logits from per-residue embeddings.

    For every residue pair (i, j) the feature is the concatenation of the
    element-wise sum and element-wise max of the two embeddings, passed
    through a small MLP (hidden width 128, scalar output).
    """

    def __init__(self, hidden_size):
        """
        Args:
            hidden_size: per-residue embedding width; pair features are
                ``2 * hidden_size`` wide.
        """
        super().__init__()
        self.activation_func = nn.functional.relu
        last_size = hidden_size * 2  # [sum ; max] pair feature width
        self.layers = torch.nn.ModuleList()
        self.final_activation = torch.nn.Sigmoid()
        for out_size in (128, 1):
            linear = torch.nn.Linear(last_size, out_size, bias=True)
            last_size = out_size
            torch.nn.init.kaiming_uniform_(linear.weight, nonlinearity='relu')
            torch.nn.init.zeros_(linear.bias)
            self.layers.append(linear)

    def forward(self, embeddings, **kwargs):
        """Map (batch, L, dim) embeddings to (batch, L, L, 1) contact logits."""
        pair_sum = embeddings[:, :, None] + embeddings[:, None, :]
        pair_max = torch.max(embeddings[:, :, None], embeddings[:, None, :])
        logits = torch.cat([pair_sum, pair_max], dim=-1)
        for idx, layer in enumerate(self.layers):
            if idx:  # ReLU between layers, none before the first
                logits = self.activation_func(logits)
            logits = layer(logits)
        return logits
class ContatcLoss(nn.Module):
    """BCE-with-logits loss over valid residue-pair contact predictions.

    Contributing pairs are upper-triangular (i < j), have sequence
    separation |i - j| > 6, and have both residues inside the attention
    mask.

    NOTE(review): the evaluation metric ``top_L_div_5_precision`` keeps
    pairs with separation >= 6 while this loss requires > 6 — confirm
    which threshold is intended.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits, labels, attn_masks):
        """
        Args:
            logits: contact logits, shape (batch_size, L, L) or (batch_size, L, L, 1).
            labels: binary contact map, shape (batch_size, L, L).
            attn_masks: valid-token mask, shape (batch_size, L).

        Returns:
            Scalar BCE-with-logits loss over the selected pairs.
        """
        scores = logits.squeeze(-1).float()
        _, seq_len, _ = scores.shape
        # Pairs where BOTH residues are valid tokens.
        pair_valid = (attn_masks.unsqueeze(2) * attn_masks.unsqueeze(1)).bool()
        # Drop short-range pairs: require |i - j| > 6.
        positions = torch.arange(seq_len, device=scores.device)
        separation_ok = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs() > 6
        # Count each pair once via the strict upper triangle.
        upper = torch.triu(
            torch.ones((seq_len, seq_len), dtype=torch.bool, device=scores.device),
            diagonal=1,
        )
        selected = pair_valid & separation_ok.unsqueeze(0) & upper.unsqueeze(0)
        return torch.nn.functional.binary_cross_entropy_with_logits(
            scores[selected], labels[selected].float()
        )
# def metric_eval(pred_y, y, inds, ls, lens):
# tests = []
# t_y = []
# rs = []
# for idx in inds:
# row = idx // lens
# col = idx % lens
# if row >= col:
# continue
# if abs(row - col) <= 6:
# continue
# p = pred_y[idx]
# gt = y[idx]
# tests.append((p,gt))
# if len(tests)>=ls:
# break
# cnt = 0
# for p, gt in tests:
# if gt == 1:
# cnt += 1
# return cnt, ls, cnt/ls
# def contact_metrics(preds, labels, attn_masks):
# '''
# pred, label: [B, L, L]
# '''
# total_acc = 0
# valid_samples = 0
# for b in range(preds.shape[0]):
# pred = preds[b]
# label = labels[b]
# mask = attn_masks[b]==1
# pred = pred[:mask.sum(), :mask.sum()]
# label = label[:mask.sum(), :mask.sum()]
# label[label>0] = -1
# label[label==0] = 1
# label[label==-1] = 0
# pred = pred.reshape(-1)
# label = label.reshape(-1)
# indices = torch.argsort(-pred)
# l = label.shape[-1]
# _,_, acc = metric_eval(pred, label, indices, l//5, l)
# total_acc += acc
# valid_samples += 1
# return {"Top(L/5)": total_acc / valid_samples if valid_samples > 0 else 0.0}
def top_L_div_5_precision(preds, labels, attn_masks):
    """Top-(L/5) contact precision, averaged over the batch.

    Candidate pairs are upper-triangular (i < j) with sequence separation
    |i - j| >= 6 and both positions valid under the attention mask. The
    top L//5 highest-scoring candidates are taken and the fraction that
    are true contacts is the per-sample precision.

    NOTE(review): the training loss (ContatcLoss) uses separation > 6
    while this metric uses >= 6 — confirm which threshold is intended.

    Args:
        preds: contact logits, shape (batch_size, L, L).
        labels: binary contact map, shape (batch_size, L, L).
        attn_masks: valid-token mask, shape (batch_size, L).

    Returns:
        {'Top(L/5)': 0-dim tensor, mean precision over the batch}.
    """
    batch_size, L, _ = preds.shape
    device = preds.device
    # Static (sample-independent) candidate mask.
    idxs = torch.arange(L, device=device)
    separation_ok = (idxs.unsqueeze(0) - idxs.unsqueeze(1)).abs() >= 6
    upper = torch.triu(torch.ones((L, L), dtype=torch.bool, device=device), diagonal=1)
    static_mask = separation_ok & upper
    precisions = []
    for b in range(batch_size):
        mask = attn_masks[b]
        valid_pairs = (mask.unsqueeze(0) * mask.unsqueeze(1)).bool() & static_mask
        # Sigmoid is monotonic, so ranking by probability == ranking by logit.
        pred_probs = torch.sigmoid(preds[b][valid_pairs].flatten())
        true_labels = labels[b][valid_pairs].flatten()
        num_top = min(max(1, L // 5), pred_probs.numel())
        if num_top == 0:
            # No candidate pairs (fully padded / very short sequence):
            # previously this divided by zero and produced NaN; count the
            # sample as zero precision instead.
            precisions.append(torch.zeros((), device=device))
            continue
        top_indices = torch.topk(pred_probs, k=num_top).indices
        precisions.append(true_labels[top_indices].sum().float() / num_top)
    return {'Top(L/5)': torch.stack(precisions).mean()}