import torch
import torch.nn as nn
from src.model.pretrain_model_interface import PretrainModelInterface


class UniModel(nn.Module):
    """Projects embeddings from a pretrained backbone into a shared hidden space and
    applies a task-specific head; supports 'adapter' and 'peft' fine-tuning."""

    def __init__(
        self,
        pretrain_model_name: str,
        task_type: str,
        finetune_type: str,
        num_classes: int,
        peft_type: str = "lora",
        **kwargs
    ):
        super().__init__()
        self.pretrain_model_name = pretrain_model_name
        self.task_type = task_type
        self.finetune_type = finetune_type
        hid_dim = 480
        # Embedding dimension produced by each supported pretrained backbone.
        input_dims = {
            'esm2_35m': 480,
            'esm2_150m': 640,
            'esm2_650m': 1280,
            'esm2_3b': 2560,
            'esm2_15b': 5120,
            'esm3_1.4b': 1536,
            'esmc_600m': 1152,
            'progen2': 1536,
            'prostt5': 2048,
            'protgpt2': 1280,
            'protrek_35m': 480 * 2,
            'protrek': 1920,
            'saport': 1280,
            'saport_35m': 480,
            'saport_1.3b': 1280,
            'procyon': 4096,
            'prollama': 4096,
            'prost': 512,
            'gearnet': 3072,
            'venusplm': 1024,
            'prosst2048': 768,
            'prott5': 1024,
            'dplm': 1280,
            'dplm_150m': 640,
            'dplm_3b': 2560,
            'ontoprotein': 1024,
            'ankh_base': 768,
            'pglm': 2048,
            'pglm-3b': 2560,
        }
        if pretrain_model_name not in input_dims:
            raise ValueError(f"Unsupported pretrain_model_name: {pretrain_model_name}")
        self.input_dim = input_dims[pretrain_model_name]

        # Project SMILES features (2048-d) and backbone embeddings into the shared hidden space.
        self.smiles_proj = nn.Sequential(
            nn.Linear(2048, hid_dim),
        )
        self.proj = nn.Sequential(
            nn.Linear(self.input_dim, hid_dim),
        )
        self.layernorm = nn.LayerNorm(hid_dim)

        if finetune_type == 'adapter':
            self.adapter = TransformerAdapter(
                input_dim=hid_dim,
                hidden_dim=hid_dim,
                num_layers=6,
                num_heads=20,
            )
        elif finetune_type == 'peft':
            self.pretrain_model_interface = PretrainModelInterface(
                pretrain_model_name,
                task_type=self.task_type
            )
            self.pretrain_model_interface.setup_peft(
                peft_type=peft_type,
                **kwargs
            )
            self.pretrain_model = self.pretrain_model_interface.pretrain_model.model

        # Task-specific heads and losses.
        if task_type in ['classification', 'residual_classification']:
            self.task_head = nn.Linear(hid_dim, num_classes)
            self.loss = nn.CrossEntropyLoss()

        if task_type in [
            "regression",
            "pair_regression"
        ]:
            self.task_head = nn.Sequential(
                nn.Linear(hid_dim, 1),
                nn.Flatten(start_dim=0, end_dim=1)
            )
            self.loss = nn.MSELoss()

        if task_type == 'contact':
            self.task_head = ContactPredictionHead(hid_dim)
            self.loss = ContactLoss()

        if task_type in [
            'binary_classification',
            'pair_binary_classification',
            'multi_labels_classification',
        ]:
            self.task_head = nn.Linear(hid_dim, num_classes)
            self.loss = nn.BCEWithLogitsLoss()

    def forward(self, batch):
        if self.finetune_type == 'adapter':
            # Precomputed embeddings are provided directly in the batch.
            labels = batch['label']
            attention_mask = batch['attention_mask']
            embeddings = batch['embedding']
            proj_output = self.proj(embeddings)
            proj_output = self.adapter(proj_output, mask=attention_mask)
            if batch['smiles'] is not None:
                # Prepend the projected SMILES embedding as an extra token.
                smiles = batch['smiles']
                smiles_proj_output = self.smiles_proj(smiles).unsqueeze(1)
                smiles_attention_mask = torch.ones(
                    attention_mask.shape[0], 1, device=attention_mask.device
                ).bool()
                proj_output = torch.cat((smiles_proj_output, proj_output), dim=1).contiguous()
                attention_mask = torch.cat((smiles_attention_mask, attention_mask), dim=-1).contiguous()

        elif self.finetune_type == 'peft':
            # Embeddings come from the (PEFT-wrapped) pretrained backbone.
            out = self.pretrain_model_interface(batch)
            embeddings, labels, attention_mask, smiles = out
            proj_output = self.proj(embeddings.to(self.proj[0].weight.dtype))
            if smiles is not None:
                smiles_proj_output = self.smiles_proj(smiles).unsqueeze(1)
                smiles_attention_mask = torch.ones(
                    attention_mask.shape[0], 1, device=attention_mask.device
                ).bool()
                proj_output = torch.cat((smiles_proj_output, proj_output), dim=1).contiguous()
                attention_mask = torch.cat((smiles_attention_mask, attention_mask), dim=-1).contiguous()

        proj_output = self.layernorm(proj_output)

        if self.task_type == 'contact':
            logits = self.task_head(proj_output)
            loss = self.loss(logits, labels.float(), attention_mask)
            return {'loss': loss, 'logits': logits, 'label': labels, 'attention_mask': attention_mask}
        elif self.task_type == 'residual_classification':
            # Per-residue classification: score only the unmasked positions.
            logits = self.task_head(proj_output)
            logits = logits[attention_mask]
            labels = labels[attention_mask]
            loss = self.loss(logits, labels.long())
            return {'loss': loss, 'logits': logits, 'label': labels, 'attention_mask': attention_mask}
        else:
            # Sequence-level tasks: mean-pool over the length dimension.
            pooled_output = torch.mean(proj_output, dim=1)
            logits = self.task_head(pooled_output)
            if isinstance(self.loss, nn.BCEWithLogitsLoss):
                labels = labels.float()
                if labels.ndim == 1:
                    labels = labels.unsqueeze(1)
            elif isinstance(self.loss, nn.CrossEntropyLoss):
                labels = labels.long()
            else:
                labels = labels.to(logits.dtype)
            loss = self.loss(logits, labels)
            return {'loss': loss, 'logits': logits, 'label': labels, 'attention_mask': attention_mask}
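
# Illustrative usage of UniModel in 'adapter' mode (a hedged sketch, not from the
# original file; the shapes and model name are assumptions based on the code above):
#   model = UniModel('esm2_35m', 'classification', 'adapter', num_classes=2)
#   batch = {'embedding': float tensor (B, L, 480),
#            'attention_mask': bool tensor (B, L),
#            'label': tensor (B,),
#            'smiles': None}
#   out = model(batch)  # dict with 'loss', 'logits', 'label', 'attention_mask'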


class TransformerAdapter(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
        super(TransformerAdapter, self).__init__()

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

    def forward(self, x, mask=None):
        """
        x: input embeddings of shape (batch_size, seq_len, dim)
        mask: boolean attention mask of shape (batch_size, seq_len), True for valid tokens
        """
        # nn.TransformerEncoder expects True where positions should be ignored,
        # so the valid-token mask is inverted here.
        output = self.transformer_encoder(x, src_key_padding_mask=~mask)
        return output


class ContactPredictionHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        # Pairwise features concatenate (sum, max) of residue embeddings, doubling the width.
        hidden_size *= 2
        self.activation_func = nn.functional.relu
        last_size = hidden_size
        self.layers = torch.nn.ModuleList()
        self.final_activation = torch.nn.Sigmoid()
        for sz in [128, 1]:
            this_layer = torch.nn.Linear(last_size, sz, bias=True)
            last_size = sz
            torch.nn.init.kaiming_uniform_(this_layer.weight, nonlinearity='relu')
            torch.nn.init.zeros_(this_layer.bias)
            self.layers.append(this_layer)

    def forward(self, embeddings, **kwargs):
        # Build a (batch, L, L, 2*hidden) pairwise representation from per-residue embeddings.
        pairwise_sum = embeddings[:, :, None] + embeddings[:, None, :]
        pairwise_max = torch.max(embeddings[:, :, None], embeddings[:, None, :])
        logits = torch.cat([pairwise_sum, pairwise_max], dim=-1)
        for i, layer in enumerate(self.layers):
            if i > 0:
                logits = self.activation_func(logits)
            logits = layer(logits)
        return logits


class ContactLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, logits, labels, attn_masks):
        """
        logits: Tensor of shape (batch_size, L, L, 1)
        labels: Tensor of shape (batch_size, L, L)
        attn_masks: Tensor of shape (batch_size, L)
        """
        logits = logits.squeeze(-1).float()
        batch_size, L, _ = logits.shape

        # Keep only pairs where both residues are real (non-padding) tokens.
        pairwise_mask = (attn_masks.unsqueeze(2) * attn_masks.unsqueeze(1)).bool()

        # Keep only pairs separated by more than 6 positions in sequence.
        idxs = torch.arange(L, device=logits.device)
        distance_mask = (idxs.unsqueeze(0) - idxs.unsqueeze(1)).abs() > 6

        # Count each pair once by keeping the strict upper triangle.
        upper_triangle_mask = torch.triu(torch.ones((L, L), dtype=torch.bool, device=logits.device), diagonal=1)

        final_mask = pairwise_mask & distance_mask.unsqueeze(0) & upper_triangle_mask.unsqueeze(0)

        logits = logits[final_mask]
        labels = labels[final_mask]

        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
        return loss


def top_L_div_5_precision(preds, labels, attn_masks):
    """
    preds: logits Tensor of shape (batch_size, L, L)
    labels: Tensor of shape (batch_size, L, L)
    attn_masks: Tensor of shape (batch_size, L)
    """
    batch_size, L, _ = preds.shape
    precisions = []

    # Static masks shared by every sample: sequence separation >= 6, upper triangle only.
    idxs = torch.arange(L, device=preds.device)
    distance_mask = (idxs.unsqueeze(0) - idxs.unsqueeze(1)).abs() >= 6
    upper_triangle_mask = torch.triu(torch.ones((L, L), dtype=torch.bool, device=preds.device), diagonal=1)
    combined_static_mask = distance_mask & upper_triangle_mask

    for b in range(batch_size):
        pred = preds[b]
        label = labels[b]
        mask = attn_masks[b]

        # Restrict to pairs where both residues are real (non-padding) tokens.
        valid_mask = (mask.unsqueeze(0) * mask.unsqueeze(1)).bool()
        combined_mask = valid_mask & combined_static_mask

        pred_scores = pred[combined_mask].flatten()
        true_labels = label[combined_mask].flatten()

        pred_probs = torch.sigmoid(pred_scores)

        # Precision among the top L/5 highest-scoring candidate contacts.
        num_top = max(1, L // 5)
        if pred_probs.numel() < num_top:
            num_top = pred_probs.numel()
        topk = torch.topk(pred_probs, k=num_top)
        top_indices = topk.indices

        top_true = true_labels[top_indices]
        precision = top_true.sum().float() / num_top
        precisions.append(precision)

    return {'Top(L/5)': torch.stack(precisions).mean()}
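

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file): exercises the
    # contact head, loss, and Top(L/5) metric on random tensors. The batch size,
    # sequence length, and hidden width below are illustrative assumptions.
    B, L, H = 2, 32, 480
    embeddings = torch.randn(B, L, H)
    labels = (torch.rand(B, L, L) > 0.95).float()
    labels = torch.triu(labels, diagonal=1)
    labels = labels + labels.transpose(1, 2)  # symmetric contact map
    attn_masks = torch.ones(B, L)

    head = ContactPredictionHead(H)
    logits = head(embeddings)  # (B, L, L, 1)
    loss = ContactLoss()(logits, labels, attn_masks)
    metrics = top_L_div_5_precision(logits.squeeze(-1), labels, attn_masks)
    print(loss.item(), metrics['Top(L/5)'].item())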