| """ |
| MTL Peptide Classifier - PDeepPP Architecture |
| All 19 peptide activity datasets trained jointly with frozen ESM-2 backbone. |
| |
| This is the model architecture used for Original_MTL_19tasks_aggressive training. |
| """ |
|
|
import random
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import EsmModel, EsmTokenizer
|
|
|
|
| |
| |
| |
|
|
class MTLPeptideClassifier(nn.Module):
    """
    Multi-Task Learning classifier for peptide activities.
    Based on PDeepPP architecture with frozen ESM-2 backbone.

    Architecture:
    - Frozen ESM-2 (650M params) as base encoder
    - Learnable base embedding for amino acids
    - Weighted combination of ESM-2 and base embeddings
    - Shared transformer encoder for global context [ablatable]
    - Shared CNN for local features [ablatable]
    - Task-specific classification heads

    Ablation flags:
    use_transformer: include shared transformer encoder (default True)
    use_cnn: include shared CNN branch (default True)
    unfreeze_esm: allow ESM-2 gradients to flow (default False)
    mask_padding: hide padding tokens from the shared branches (default False)

    Padding note: with ``mask_padding=False`` (the original behaviour) the
    shared transformer attends to padding positions, and the CNN's 7-wide
    kernel mixes padding embeddings into the last real positions. The
    default keeps that behaviour so existing checkpoints reproduce their
    training-time outputs. Enable ``mask_padding`` for new training runs.
    """

    def __init__(
        self,
        task_configs: Dict[str, Dict],
        hidden_dim: int = 1280,
        esm_ratio: float = 0.9,
        num_transformer_layers: int = 4,
        dropout: float = 0.3,
        use_transformer: bool = True,
        use_cnn: bool = True,
        unfreeze_esm: bool = False,
        mask_padding: bool = False,
    ):
        """
        Args:
            task_configs: {task_name: {'num_classes': int, 'csv_prefix': str}}
            hidden_dim: Hidden dimension for base embedding. Must equal the
                ESM-2 hidden size, because the two embeddings are summed.
            esm_ratio: Weight for ESM-2 vs base embedding (0-1)
            num_transformer_layers: Layers in shared transformer
            dropout: Dropout rate
            use_transformer: Enable shared transformer encoder
            use_cnn: Enable shared CNN branch
            unfreeze_esm: Unfreeze ESM-2 backbone for fine-tuning
            mask_padding: Exclude padding positions (attention_mask == 0)
                from transformer attention and zero them before the CNN

        Raises:
            ValueError: if hidden_dim differs from the ESM-2 hidden size.
        """
        super().__init__()

        self.use_transformer = use_transformer
        self.use_cnn = use_cnn
        self.unfreeze_esm = unfreeze_esm
        self.mask_padding = mask_padding

        # Pretrained ESM-2 backbone; parameters only receive gradients when unfrozen.
        self.esm = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm.requires_grad_(bool(unfreeze_esm))

        esm_dim = self.esm.config.hidden_size
        if hidden_dim != esm_dim:
            # encode() adds the ESM-2 and base embeddings element-wise, so the
            # widths must agree. Fail here rather than at the first forward pass.
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must equal the ESM-2 hidden size ({esm_dim})"
            )

        # Learnable per-token embedding over 33 token ids (assumed to match
        # the ESM-2 tokenizer vocabulary).
        self.base_embed = nn.Embedding(33, hidden_dim)
        self.esm_ratio = esm_ratio

        # Global-context branch.
        if use_transformer:
            self.transformer = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=hidden_dim,
                    nhead=8,
                    dim_feedforward=hidden_dim * 4,
                    dropout=dropout,
                    batch_first=True
                ),
                num_layers=num_transformer_layers
            )

        # Local-feature branch; padding=3 keeps the sequence length for kernel 7.
        if use_cnn:
            self.cnn = nn.Conv1d(
                hidden_dim,
                hidden_dim,
                kernel_size=7,
                padding=3
            )
            self.layer_norm = nn.LayerNorm(hidden_dim)

        # The two branch outputs are concatenated when both are active.
        if use_transformer and use_cnn:
            self.feature_dim = hidden_dim * 2
        else:
            self.feature_dim = hidden_dim

        self.heads = nn.ModuleDict()
        for name, cfg in task_configs.items():
            self.heads[name] = SequenceHead(
                input_dim=self.feature_dim,
                num_classes=cfg['num_classes'],
                dropout=dropout
            )

        esm_status = "Unfrozen" if unfreeze_esm else "Frozen"
        branches = []
        if use_transformer:
            branches.append(f"{num_transformer_layers}-layer Transformer")
        if use_cnn:
            branches.append("CNN")
        if not branches:
            branches.append("Pass-through (embedding only)")
        print(f"* MTL Model initialized with {len(task_configs)} tasks")
        print(f" - ESM-2: {esm_status} (650M params), ratio={esm_ratio}")
        print(f" - Base Embedding: {hidden_dim} dim")
        print(f" - Shared Backbone: {' + '.join(branches)}")
        print(f" - Feature dim: {self.feature_dim}")
        print(f" - Task Heads: {len(task_configs)} sequence-level")

    def encode(self, input_ids, attention_mask):
        """
        Encode sequences through shared backbone.
        Returns: [B, L, feature_dim]
        - feature_dim = 2*hidden_dim when both Transformer + CNN active
        - feature_dim = hidden_dim when only one branch (or none) is active
        """
        if self.unfreeze_esm:
            esm_out = self.esm(input_ids, attention_mask).last_hidden_state
        else:
            # Frozen backbone: skip building an autograd graph for ESM-2.
            with torch.no_grad():
                esm_out = self.esm(input_ids, attention_mask).last_hidden_state

        base_out = self.base_embed(input_ids)
        x = self.esm_ratio * esm_out + (1 - self.esm_ratio) * base_out

        # getattr keeps fully pickled models saved before this flag existed working.
        if getattr(self, "mask_padding", False):
            padding_mask = attention_mask == 0  # True marks positions to ignore
            cnn_input = x * attention_mask.unsqueeze(-1).to(x.dtype)
        else:
            padding_mask = None
            cnn_input = x

        features = []
        if self.use_transformer:
            features.append(self.transformer(x, src_key_padding_mask=padding_mask))
        if self.use_cnn:
            local_feat = self.cnn(cnn_input.transpose(1, 2)).transpose(1, 2)
            features.append(self.layer_norm(local_feat))

        if not features:
            return x  # pass-through: embeddings only
        if len(features) == 1:
            return features[0]
        # Order is [global, local], matching feature_dim and the trained heads.
        return torch.cat(features, dim=-1)

    def forward(self, input_ids, attention_mask, task_name):
        """
        Forward pass for specific task.
        Args:
            input_ids: [B, L] token IDs
            attention_mask: [B, L] attention mask
            task_name: which task head to use
        Returns:
            logits: [B, num_classes]
        """
        shared_repr = self.encode(input_ids, attention_mask)
        return self.heads[task_name](shared_repr, attention_mask)

    def get_trainable_params(self):
        """Return count of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
class SequenceHead(nn.Module):
    """
    Per-task classifier: masked mean-pooling over positions, then an MLP.

    The MLP maps input_dim -> 256 -> 128 -> num_classes, with ReLU and
    dropout after each hidden layer. The default of two classes targets
    binary peptide activity prediction.
    """

    def __init__(self, input_dim: int, num_classes: int = 2, dropout: float = 0.3):
        super().__init__()
        widths = (input_dim, 256, 128)
        layers = []
        for width_in, width_out in zip(widths, widths[1:]):
            layers.extend([nn.Linear(width_in, width_out), nn.ReLU(), nn.Dropout(dropout)])
        layers.append(nn.Linear(widths[-1], num_classes))
        # Submodule indices fc.0 ... fc.6 stay stable so saved state_dicts load.
        self.fc = nn.Sequential(*layers)

    def forward(self, x, attention_mask):
        """
        Args:
            x: [B, L, D] shared representation
            attention_mask: [B, L]; 1 for positions to pool, 0 for padding
        Returns:
            logits: [B, num_classes]
        """
        weights = attention_mask.unsqueeze(-1).float()    # [B, L, 1]
        totals = (x * weights).sum(dim=1)                 # [B, D]
        # Clamp guards against division by zero for an all-padding row.
        counts = weights.sum(dim=1).clamp(min=1e-9)       # [B, 1]
        return self.fc(totals / counts)
|
|
|
|
| |
| |
| |
|
|
class PeptideDataset(Dataset):
    """One peptide task's CSV exposed as tokenized tensors for MTL training.

    Accepts either lower-case (``sequence``/``label``) or capitalised
    (``Sequence``/``Label``) column names. Rows with a missing sequence or
    label are dropped.
    """

    def __init__(self, csv_path: str, tokenizer, max_length: int = 128):
        frame = pd.read_csv(csv_path)
        seq_col = "sequence" if "sequence" in frame.columns else "Sequence"
        label_col = "label" if "label" in frame.columns else "Label"

        # Drop incomplete rows, then also reject sequences that stringify to 'nan'.
        frame = frame.dropna(subset=[seq_col, label_col])
        seq_text = frame[seq_col].astype(str)
        usable = seq_text != "nan"

        self.sequences = seq_text[usable].tolist()
        self.labels = frame[label_col][usable].tolist()
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = str(self.sequences[idx])
        target = int(self.labels[idx])

        # Residues are space-separated so the tokenizer sees one residue per word.
        encoding = self.tokenizer(
            " ".join(seq),
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': torch.tensor(target, dtype=torch.long),
        }
|
|
|
|
class MultiTaskDataLoader:
    """
    Multi-task dataloader with task sampling.

    Each step picks a task uniformly at random (via the module-level
    ``random``) among the tasks that still have batches left this epoch, and
    yields ``(batch, task_name)``. The epoch ends when every task's
    DataLoader is exhausted, so each batch of each task is yielded exactly
    once per epoch.
    """

    def __init__(
        self,
        task_datasets: Dict[str, Dataset],
        batch_size: int = 16,
        num_workers: Optional[int] = None,
    ):
        """
        Args:
            task_datasets: {task_name: Dataset}
            batch_size: batch size per task
            num_workers: DataLoader worker processes per task. ``None`` keeps
                the platform default: 0 on Windows, 2 elsewhere.
        """
        if num_workers is None:
            import platform
            num_workers = 0 if platform.system() == "Windows" else 2
        # Pinned host memory only helps host-to-GPU copies; without CUDA,
        # PyTorch just warns about it.
        pin = torch.cuda.is_available()

        self.task_loaders = {}
        for task_name, dataset in task_datasets.items():
            self.task_loaders[task_name] = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=num_workers,
                pin_memory=pin
            )
        self.task_names = list(task_datasets.keys())
        self.batch_size = batch_size

    def __iter__(self):
        """Yield (batch, task_name) pairs from randomly sampled tasks."""
        iters = {name: iter(loader) for name, loader in self.task_loaders.items()}

        while iters:
            task = random.choice(list(iters))
            # Keep the try body to next() only, so exceptions thrown into the
            # generator at the yield are not mistaken for task exhaustion.
            try:
                batch = next(iters[task])
            except StopIteration:
                del iters[task]  # task exhausted for this epoch
                continue
            yield batch, task

    def __len__(self):
        # Exact number of batches per epoch: the sum over all task loaders.
        return sum(len(loader) for loader in self.task_loaders.values())

    def get_task_batches(self, task_name: str):
        """Return every batch of one task's (shuffled) loader as a list.

        The whole task is materialised in memory.
        """
        return list(self.task_loaders[task_name])
|
|
|
|
| |
| |
| |
|
|
def get_all_peptide_tasks(data_dir: str) -> Dict[str, Dict]:
    """
    Auto-detect peptide tasks from a UniDL4BioPep data directory.
    Returns task_configs for MTL model.

    A task is registered for every ``<prefix>_train.csv`` in ``data_dir``
    whose prefix appears in ``task_mappings`` below; other files are
    ignored. The result therefore holds only the known tasks whose training
    CSV is present (at most 22).

    Known tasks:
    1. ACE_inhibitory - ACE inhibitory activity
    2. DPPIV_inhibitory - DPPIV inhibitory activity
    3. Bitter - Bitter taste peptides
    4. Umami - Umami taste peptides
    5. Antimicrobial - Antimicrobial activity
    6. Antimalarial - Antimalarial activity (main)
    7. Antimalarial_alt - Antimalarial activity (alternative)
    8. Quorum_sensing - Quorum sensing activity
    9. Anticancer - Anticancer activity (main)
    10. Anticancer_alt - Anticancer activity (alternative)
    11. AntiMRSA - Anti-MRSA strains activity
    12. TTCA - Therapeutic peptides for cancer
    13. BBP - Blood-Brain Barrier peptides
    14. Anti_parasitic - Anti-parasitic peptides
    15. NeuroPred - Neuroprotective peptides
    16. Antibacterial - Antibacterial peptides
    17. Antifungal - Antifungal peptides
    18. Antiviral - Antiviral peptides
    19. Toxicity - Toxicity prediction
    20. Anti_inflammatory - Anti-inflammatory peptides
    21. Signal_peptide - Signal peptides
    22. Antioxidant - Antioxidant peptides (FRS dataset)

    Args:
        data_dir: directory containing the ``*_train.csv`` files.

    Returns:
        ``{task_name: {'num_classes': int, 'csv_prefix': str}}`` in sorted
        filename order. ``num_classes`` is at least 2 and, for numeric
        labels, at least ``max(label) + 1``, so every label is a valid
        class index.
    """
    data_path = Path(data_dir)

    # File prefix (file name minus "_train.csv") -> task name.
    task_mappings = {
        "1__ACE_inhibitory_activity": "ACE_inhibitory",
        "2__DPPIV_inhibitory_activity": "DPPIV_inhibitory",
        "3__Bitter": "Bitter",
        "4__Umami": "Umami",
        "5__Antimicrobial_activity": "Antimicrobial",
        "6__Antimalarial_activity-main": "Antimalarial",
        "6__Antimalarial_activity-alternative": "Antimalarial_alt",
        "7__Quorum_sensing_activity": "Quorum_sensing",
        "8__ACP_Anticancer_activity-main": "Anticancer",
        "8__ACP_Anticancer_activity-alternative": "Anticancer_alt",
        "9__Anti-MRSA_strains_activity": "AntiMRSA",
        "10__TTCA": "TTCA",
        "11__BBP_Blood-Brain_Barrier_Peptides": "BBP",
        "12__APP__Anti-parasitic": "Anti_parasitic",
        "13_NeuroPred": "NeuroPred",
        "14__antibacterial_AB": "Antibacterial",
        "15__Antifungal_AF": "Antifungal",
        "16__AV_Antiviral": "Antiviral",
        "17__Toxicity_2021_Dataset": "Toxicity",
        "18__Anti_inflammatory_peptides": "Anti_inflammatory",
        "19__Signal_peptides": "Signal_peptide",
        "21__Antioxidant_FRS": "Antioxidant"
    }

    suffix = "_train"
    task_configs = {}
    # Sorted so task order is identical across machines; glob order is not guaranteed.
    for csv_file in sorted(data_path.glob(f"*{suffix}.csv")):
        # Strip only the trailing "_train"; str.replace would also remove an
        # earlier occurrence inside the prefix.
        prefix = csv_file.stem[:-len(suffix)]
        task_name = task_mappings.get(prefix)
        if task_name is None:
            continue

        df = pd.read_csv(csv_file)
        label_col = 'label' if 'label' in df.columns else 'Label'
        n_classes = 2
        if label_col in df.columns:
            labels = df[label_col].dropna()
            n_classes = labels.nunique()
            if not labels.empty and pd.api.types.is_numeric_dtype(labels):
                # Labels are later used as class indices, so the head needs
                # max(label) + 1 outputs even when labels are non-contiguous.
                n_classes = max(n_classes, int(labels.max()) + 1)
        n_classes = max(n_classes, 2)

        task_configs[task_name] = {
            'num_classes': n_classes,
            'csv_prefix': prefix
        }

    return task_configs
|
|
|
|
| |
| |
| |
|
|
class TIMLoss(nn.Module):
    """
    Threshold-Independent Multi-task Loss for imbalanced datasets.
    Reference: https://arxiv.org/abs/2008.10599

    Uses learnable task-specific weights (log variances) to balance
    losses across tasks with different scales and difficulties.

    Each per-sample loss ``l`` of task ``t`` becomes ``exp(-s_t) * l + s_t``,
    where ``s_t = log_vars[t]``. The ``+ s_t`` term stops the optimiser from
    driving every weight to zero. Up to a constant factor, this is the
    homoscedastic-uncertainty weighting of Kendall et al. (2018).
    NOTE(review): confirm the arXiv ID above points to the intended paper.
    """

    def __init__(self, num_tasks: int):
        super().__init__()
        # One learnable log-variance per task; 0 means an initial weight of 1.
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses: torch.Tensor, task_indices: torch.Tensor):
        """
        Args:
            losses: [B] per-sample losses
            task_indices: [B] task index of each sample (tensor or int sequence)
        Returns:
            weighted_loss: scalar mean of the weighted per-sample losses
        Raises:
            ValueError: if ``losses`` is empty (the mean would be NaN).
        """
        if losses.numel() == 0:
            raise ValueError("TIMLoss received an empty batch of losses")

        # Vectorised gather replaces a per-sample Python loop, which needed one
        # .item() host sync per element.
        idx = torch.as_tensor(task_indices, dtype=torch.long, device=self.log_vars.device)
        sample_log_vars = self.log_vars[idx]
        weighted = torch.exp(-sample_log_vars) * losses + sample_log_vars
        return weighted.mean()
|
|
|
|
| |
| |
| |
|
|
def create_model_and_loaders(
    data_dir: str,
    batch_size: int = 16,
    max_length: int = 128
):
    """
    Create MTL model, datasets, and dataloaders.

    Args:
        data_dir: directory holding ``<prefix>_train.csv`` and
            ``<prefix>_test.csv`` files.
        batch_size: per-task batch size for the training loader.
        max_length: tokenizer max length; sequences are padded or truncated to it.

    Returns:
        ``(model, train_loader, val_datasets, task_configs)``, where:
        - ``train_loader`` is a MultiTaskDataLoader over the train splits;
        - ``val_datasets`` maps each task name to a PeptideDataset built from
          that task's ``_test.csv`` split, if present.

    Raises:
        ValueError: if no recognised task CSVs are found in ``data_dir``.
    """
    task_configs = get_all_peptide_tasks(data_dir)
    if not task_configs:
        # Fail before loading the 650M-parameter ESM-2 weights into a model
        # that would have no task heads.
        raise ValueError(
            f"No recognised peptide task files ('*_train.csv') found in {data_dir!r}"
        )
    print(f"\n* Detected {len(task_configs)} peptide tasks:")
    for name, cfg in task_configs.items():
        print(f" - {name}: {cfg['num_classes']} classes")

    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    model = MTLPeptideClassifier(
        task_configs=task_configs,
        hidden_dim=1280,
        esm_ratio=0.9,
        num_transformer_layers=4,
        dropout=0.3
    )

    train_datasets = {}
    val_datasets = {}

    for task_name, cfg in task_configs.items():
        prefix = cfg['csv_prefix']
        train_path = Path(data_dir) / f"{prefix}_train.csv"
        # The held-out "_test" split doubles as the validation set.
        val_path = Path(data_dir) / f"{prefix}_test.csv"

        if train_path.exists():
            train_datasets[task_name] = PeptideDataset(
                str(train_path),
                tokenizer,
                max_length
            )
        if val_path.exists():
            val_datasets[task_name] = PeptideDataset(
                str(val_path),
                tokenizer,
                max_length
            )

    train_loader = MultiTaskDataLoader(train_datasets, batch_size)

    print(f"\n* Created dataloaders:")
    print(f" - Train tasks: {len(train_datasets)}")
    print(f" - Val tasks: {len(val_datasets)}")
    print(f" - Approx batches/epoch: {len(train_loader)}")

    trainable = model.get_trainable_params()
    total = sum(p.numel() for p in model.parameters())
    print(f"\n* Model parameters:")
    print(f" - Total: {total:,}")
    print(f" - Trainable: {trainable:,} ({100*trainable/total:.2f}%)")

    return model, train_loader, val_datasets, task_configs
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build everything from ./datasets next to this script and
    # run a single forward pass.
    script_dir = Path(__file__).parent
    data_dir = str(script_dir / "datasets")

    model, train_loader, val_datasets, task_configs = create_model_and_loaders(
        data_dir,
        batch_size=16
    )

    print("\n" + "="*60)
    print("Testing forward pass...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Disable dropout so the smoke-test forward pass is deterministic.
    model.eval()

    for batch, task_name in train_loader:
        print(f" Task: {task_name}")
        print(f" Batch size: {batch['input_ids'].shape[0]}")

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        with torch.no_grad():
            logits = model(input_ids, attention_mask, task_name)

        print(f" Logits shape: {logits.shape}")
        print(" * Forward pass successful!")
        break
    else:
        # The loop finished without a single batch; previously this passed silently.
        print(" ! No batches produced - check the datasets directory")
|
|