MTL Peptide Classifier (22 Tasks)

A multi-task learning (MTL) peptide classifier covering 22 binary peptide-activity tasks. It is built on a frozen ESM-2 (650M) backbone with a parallel Transformer + CNN feature extractor and per-task heads, following a PDeepPP-inspired design.

Held-out Test Set Performance (Averaged across 22 tasks)

Metric    Value
Accuracy  86.63%
F1        84.99%
AUC       93.47%
MCC       72.67%

Best Val Avg F1 (used for checkpoint selection): 85.56%
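
For reference, the threshold-based metrics reported here can be derived from a task's confusion counts. The sketch below is a minimal, dependency-free illustration: the function name `binary_metrics` and the toy labels are made up for this example, and it is not the evaluation script that produced the numbers in this card. AUC is left out because it needs ranked scores rather than hard predictions. In practice, scikit-learn (listed under Requirements) provides equivalent functions.

```python
import math

def binary_metrics(y_true, y_pred):
    """Return accuracy, F1 and MCC for binary labels (1 = positive)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

    accuracy = (tp + tn) / len(y_true)
    f1_den = 2 * tp + fp + fn
    f1 = 2 * tp / f1_den if f1_den else 0.0
    # MCC uses all four cells of the confusion matrix.
    mcc_den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / mcc_den if mcc_den else 0.0
    return {"accuracy": accuracy, "f1": f1, "mcc": mcc}

# Toy example (not real model output): 8 peptides, 2 misclassified.
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 0, 1]
metrics = binary_metrics(y_true, y_pred)
print(metrics)
```

Because MCC accounts for true negatives as well as true positives, it can be much lower than accuracy on imbalanced tasks.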

Per-Task Test Metrics

Task               ACC     F1      AUC     MCC
AntiMRSA           0.9899  0.9667  0.9970  0.9607
Anticancer         0.7035  0.7344  0.8007  0.4184
ACE_inhibitory     0.6801  0.7392  0.8355  0.4040
Antioxidant        0.7117  0.7309  0.8173  0.4382
Bitter             0.8203  0.8456  0.9591  0.6782
Antimalarial       0.9736  0.7692  0.9177  0.7579
Anti_inflammatory  0.9886  0.9887  0.9979  0.9773
Antimicrobial      0.9746  0.9552  0.9915  0.9379
Signal_peptide     0.9927  0.9927  0.9997  0.9854
Antifungal         0.9465  0.9456  0.9863  0.8935
Antimalarial_alt   0.9877  0.9630  0.9942  0.9566
Anticancer_alt     0.9330  0.9316  0.9784  0.8667
Anti_parasitic     0.7826  0.7500  0.9216  0.5855
Umami              0.8427  0.7308  0.9297  0.6243
Quorum_sensing     0.9250  0.9231  0.9850  0.8511
Antibacterial      0.9431  0.9424  0.9789  0.8863
NeuroPred          0.8660  0.8543  0.9444  0.7416
Toxicity           0.9086  0.8971  0.9699  0.8178
Antiviral          0.8307  0.8319  0.9098  0.6614
DPPIV_inhibitory   0.8647  0.8732  0.9478  0.7361
BBP                0.6579  0.5185  0.9141  0.3873
TTCA               0.7360  0.8129  0.7858  0.4221
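
The averaged accuracy in the summary is consistent with an unweighted (macro) mean over tasks, meaning every task counts equally regardless of its test-set size. A quick check, using the ACC column copied from the table above:

```python
# ACC column copied from the per-task table, in row order.
per_task_acc = [
    0.9899, 0.7035, 0.6801, 0.7117, 0.8203, 0.9736, 0.9886, 0.9746,
    0.9927, 0.9465, 0.9877, 0.9330, 0.7826, 0.8427, 0.9250, 0.9431,
    0.8660, 0.9086, 0.8307, 0.8647, 0.6579, 0.7360,
]

# Unweighted mean: each task contributes 1/22 of the result.
macro_acc = sum(per_task_acc) / len(per_task_acc)
print(f"{len(per_task_acc)} tasks, macro-average accuracy = {macro_acc:.2%}")
# prints: 22 tasks, macro-average accuracy = 86.63%
```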

Architecture

  • Shared encoder: frozen ESM-2 (facebook/esm2_t33_650M_UR50D, 650M params) + learnable base embedding, mixed at esm_ratio=0.9
  • Feature extraction (parallel): 4-layer Transformer + CNN (kernel=7, padding=3) → concatenated to 2560-dim features
  • Heads: 22 binary classifiers (2560 → 256 → 128 → 2) with masked average pooling
  • Loss: TIM (Threshold-Independent Multi-task) loss + label smoothing 0.1
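
Two of the components above are easy to misread, so here is a minimal, dependency-free sketch of one plausible reading. It assumes that `esm_ratio` blends the two embeddings as a convex combination and that masked average pooling averages only over non-padding positions. Both are assumptions about `mtl_peptide_classifier.py`, not excerpts from it, and the function names are hypothetical. The real model would operate on batched tensors rather than Python lists.

```python
def mix_embeddings(esm_vec, base_vec, esm_ratio=0.9):
    """Blend one token's ESM-2 and base embeddings (assumed convex combination)."""
    return [esm_ratio * e + (1.0 - esm_ratio) * b for e, b in zip(esm_vec, base_vec)]

def masked_average_pool(token_features, attention_mask):
    """Average per-token feature vectors, skipping positions where mask == 0."""
    kept = [vec for vec, m in zip(token_features, attention_mask) if m == 1]
    if not kept:
        raise ValueError("attention_mask selects no tokens")
    dim = len(kept[0])
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]

# Toy 2-dim features for a 4-position input whose last position is padding.
features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [100.0, 100.0]]
mask = [1, 1, 1, 0]
pooled = masked_average_pool(features, mask)
print(pooled)  # the padding vector does not affect the result

mixed = mix_embeddings([1.0, 0.0], [0.0, 1.0])
print(mixed)
```

The 2560-dim head input is consistent with concatenating two 1280-dim branch outputs, matching `hidden_dim=1280` in the Usage section.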

Tasks

#   Task                 Source
1   ACE_inhibitory       UniDL4BioPep
2   DPPIV_inhibitory     UniDL4BioPep
3   Bitter               UniDL4BioPep
4   Umami                UniDL4BioPep
5   Antimicrobial        UniDL4BioPep
6   Antimalarial (main)  UniDL4BioPep
7   Antimalarial_alt     UniDL4BioPep
8   Quorum_sensing       UniDL4BioPep
9   Anticancer (main)    UniDL4BioPep
10  Anticancer_alt       UniDL4BioPep
11  AntiMRSA             UniDL4BioPep
12  TTCA                 UniDL4BioPep
13  BBP                  UniDL4BioPep
14  Anti_parasitic       UniDL4BioPep
15  NeuroPred            UniDL4BioPep
16  Antibacterial        UniDL4BioPep
17  Antifungal           UniDL4BioPep
18  Antiviral            UniDL4BioPep
19  Toxicity             UniDL4BioPep
20  Anti_inflammatory    local dataset
21  Signal_peptide       local dataset
22  Antioxidant          UniDL4BioPep (antioxidant_FRS)

Usage

import os
from huggingface_hub import hf_hub_download
import torch
from transformers import EsmTokenizer

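# mtl_peptide_classifier.py (see Files) must be on the import path, e.g. downloaded into the working directory
from mtl_peptide_classifier import MTLPeptideClassifier, get_all_peptide_tasks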

REPO = "tuankg1028/MTL-Peptide-Classifier"
checkpoint_dir = "MTL-Peptide-Classifier"
os.makedirs(checkpoint_dir, exist_ok=True)

for fname in ["heads.pt", "shared_backbone.pt", "ablation_config.json"]:
    hf_hub_download(repo_id=REPO, filename=fname, local_dir=checkpoint_dir)

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
task_configs = get_all_peptide_tasks("datasets")  # needs local datasets/ dir for task names

model = MTLPeptideClassifier(
    task_configs=task_configs,
    hidden_dim=1280,
    esm_ratio=0.9,
    num_transformer_layers=4,
    dropout=0.3,
    use_transformer=True,
    use_cnn=True,
    unfreeze_esm=False,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
backbone = torch.load(f"{checkpoint_dir}/shared_backbone.pt", map_location=device)
heads = torch.load(f"{checkpoint_dir}/heads.pt", map_location=device)

# Restore the trained backbone components; per the Files list, shared_backbone.pt does not hold the frozen ESM-2 weights.
model.base_embed.load_state_dict(backbone["base_embed"])
if "transformer" in backbone:
    model.transformer.load_state_dict(backbone["transformer"])
if "cnn" in backbone:
    model.cnn.load_state_dict(backbone["cnn"])
    model.layer_norm.load_state_dict(backbone["layer_norm"])
# Restore every per-task head that has saved weights.
for name, head in model.heads.items():
    if name in heads:
        head.load_state_dict(heads[name])

model = model.to(device).eval()

sequence = "MKWVTFISLLFLFSSAYSRGVFRR"
tokens = " ".join(sequence)  # space-separate the residues before tokenizing
inputs = tokenizer(tokens, return_tensors="pt", max_length=128, padding="max_length", truncation=True)
with torch.no_grad():
    logits = model(inputs["input_ids"].to(device), inputs["attention_mask"].to(device), task_name="Antimicrobial")
    probs = torch.softmax(logits, dim=-1)
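
Before tokenizing, it can help to validate the input. The sketch below is a hypothetical helper (`prepare_sequence` is not part of the repository) that upper-cases the sequence, rejects characters outside the 20 standard amino acids, and warns when the sequence would be truncated. It assumes the ESM tokenizer adds two special tokens (`<cls>` and `<eos>`), which would leave 126 residue positions when `max_length=128`. Rejecting non-standard residues is a deliberately strict choice for this example, not a documented requirement of the model.

```python
import warnings

STANDARD_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")

def prepare_sequence(sequence, max_length=128, num_special_tokens=2):
    """Validate a peptide and return it space-separated, as in the usage code."""
    seq = sequence.strip().upper()
    if not seq:
        raise ValueError("empty sequence")
    invalid = sorted(set(seq) - STANDARD_AMINO_ACIDS)
    if invalid:
        raise ValueError(f"non-standard residues: {''.join(invalid)}")
    # Assumed budget: max_length minus the special tokens added by the tokenizer.
    max_residues = max_length - num_special_tokens
    if len(seq) > max_residues:
        warnings.warn(f"{len(seq)} residues exceed {max_residues}; the tokenizer will truncate")
    return " ".join(seq)

tokens = prepare_sequence("mkwvtfisllflfssaysrgvfrr")
print(tokens)
```

The returned string can be passed to `tokenizer(...)` in place of the manually joined `tokens` in the usage code.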

Training

  • Base model: facebook/esm2_t33_650M_UR50D (frozen)
  • Batch size: 16, learning rate: 1e-4, 50 epochs, dropout: 0.3
  • Data split per task: training data split 80% train / 20% validation (used for checkpoint selection), plus a separate held-out test CSV evaluated once
  • Mixed precision, gradient clipping 1.0, cosine LR with 5 warmup epochs
  • TIM loss + label smoothing 0.1
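
As an illustration of the schedule and target smoothing listed above, here is a minimal sketch. It assumes linear warmup over the first 5 epochs, cosine decay to zero over the remaining 45, per-epoch updates, and the standard label-smoothing formulation. The training script may differ in any of these details (for example per-step updates or a non-zero floor), and the function names are hypothetical.

```python
import math

def lr_at_epoch(epoch, base_lr=1e-4, warmup_epochs=5, total_epochs=50):
    """Learning rate for a 0-indexed epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))

def smooth_targets(label, num_classes=2, smoothing=0.1):
    """Soft targets: (1 - eps) on the true class plus eps / K on every class."""
    off = smoothing / num_classes
    return [off + (1.0 - smoothing) if c == label else off for c in range(num_classes)]

# Learning rate at a few representative epochs.
for epoch in (0, 4, 5, 27, 49):
    print(epoch, f"{lr_at_epoch(epoch):.2e}")

# Smoothed target distribution for a positive example.
print(smooth_targets(1))
```

With two classes and smoothing 0.1, a positive example is trained toward roughly 0.95 / 0.05 instead of 1 / 0 under this formulation.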

Files

  • heads.pt: per-task classifier heads
  • shared_backbone.pt: base embedding, Transformer, CNN, LayerNorm
  • ablation_config.json: architecture configuration for reproducibility
  • test_results.json: held-out test metrics (per task + averages)
  • mtl_peptide_classifier.py: model code

Requirements

torch>=2.0.0
transformers>=4.30.0
huggingface_hub
numpy
pandas
scikit-learn