MTL Peptide Classifier (22 Tasks)
Multi-Task Learning peptide classifier covering 22 binary peptide-activity tasks. Built on a frozen ESM-2 (650M) backbone with a parallel Transformer + CNN feature extractor and per-task heads, following a PDeepPP-inspired design.
Held-out Test Set Performance (macro-averaged across 22 tasks)
| Metric | Value |
|---|---|
| Accuracy | 86.63% |
| F1 | 84.99% |
| AUC | 93.47% |
| MCC | 72.67% |
Best validation average F1 (used for checkpoint selection): 85.56%
Per-Task Test Metrics
| Task | ACC | F1 | AUC | MCC |
|---|---|---|---|---|
| AntiMRSA | 0.9899 | 0.9667 | 0.9970 | 0.9607 |
| Anticancer | 0.7035 | 0.7344 | 0.8007 | 0.4184 |
| ACE_inhibitory | 0.6801 | 0.7392 | 0.8355 | 0.4040 |
| Antioxidant | 0.7117 | 0.7309 | 0.8173 | 0.4382 |
| Bitter | 0.8203 | 0.8456 | 0.9591 | 0.6782 |
| Antimalarial | 0.9736 | 0.7692 | 0.9177 | 0.7579 |
| Anti_inflammatory | 0.9886 | 0.9887 | 0.9979 | 0.9773 |
| Antimicrobial | 0.9746 | 0.9552 | 0.9915 | 0.9379 |
| Signal_peptide | 0.9927 | 0.9927 | 0.9997 | 0.9854 |
| Antifungal | 0.9465 | 0.9456 | 0.9863 | 0.8935 |
| Antimalarial_alt | 0.9877 | 0.9630 | 0.9942 | 0.9566 |
| Anticancer_alt | 0.9330 | 0.9316 | 0.9784 | 0.8667 |
| Anti_parasitic | 0.7826 | 0.7500 | 0.9216 | 0.5855 |
| Umami | 0.8427 | 0.7308 | 0.9297 | 0.6243 |
| Quorum_sensing | 0.9250 | 0.9231 | 0.9850 | 0.8511 |
| Antibacterial | 0.9431 | 0.9424 | 0.9789 | 0.8863 |
| NeuroPred | 0.8660 | 0.8543 | 0.9444 | 0.7416 |
| Toxicity | 0.9086 | 0.8971 | 0.9699 | 0.8178 |
| Antiviral | 0.8307 | 0.8319 | 0.9098 | 0.6614 |
| DPPIV_inhibitory | 0.8647 | 0.8732 | 0.9478 | 0.7361 |
| BBP | 0.6579 | 0.5185 | 0.9141 | 0.3873 |
| TTCA | 0.7360 | 0.8129 | 0.7858 | 0.4221 |
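The summary figures appear to be unweighted (macro) means of the per-task rows. The following quick check recomputes them from the per-task table. It is a sketch for readers, not part of the released code, and `macro_average` is a hypothetical helper name.

```python
from statistics import mean

METRICS = ("ACC", "F1", "AUC", "MCC")

# (ACC, F1, AUC, MCC) per task, copied from the per-task test table above.
PER_TASK = {
    "AntiMRSA": (0.9899, 0.9667, 0.9970, 0.9607),
    "Anticancer": (0.7035, 0.7344, 0.8007, 0.4184),
    "ACE_inhibitory": (0.6801, 0.7392, 0.8355, 0.4040),
    "Antioxidant": (0.7117, 0.7309, 0.8173, 0.4382),
    "Bitter": (0.8203, 0.8456, 0.9591, 0.6782),
    "Antimalarial": (0.9736, 0.7692, 0.9177, 0.7579),
    "Anti_inflammatory": (0.9886, 0.9887, 0.9979, 0.9773),
    "Antimicrobial": (0.9746, 0.9552, 0.9915, 0.9379),
    "Signal_peptide": (0.9927, 0.9927, 0.9997, 0.9854),
    "Antifungal": (0.9465, 0.9456, 0.9863, 0.8935),
    "Antimalarial_alt": (0.9877, 0.9630, 0.9942, 0.9566),
    "Anticancer_alt": (0.9330, 0.9316, 0.9784, 0.8667),
    "Anti_parasitic": (0.7826, 0.7500, 0.9216, 0.5855),
    "Umami": (0.8427, 0.7308, 0.9297, 0.6243),
    "Quorum_sensing": (0.9250, 0.9231, 0.9850, 0.8511),
    "Antibacterial": (0.9431, 0.9424, 0.9789, 0.8863),
    "NeuroPred": (0.8660, 0.8543, 0.9444, 0.7416),
    "Toxicity": (0.9086, 0.8971, 0.9699, 0.8178),
    "Antiviral": (0.8307, 0.8319, 0.9098, 0.6614),
    "DPPIV_inhibitory": (0.8647, 0.8732, 0.9478, 0.7361),
    "BBP": (0.6579, 0.5185, 0.9141, 0.3873),
    "TTCA": (0.7360, 0.8129, 0.7858, 0.4221),
}

# Averages from the summary table, converted from percent to fractions.
REPORTED = {"ACC": 0.8663, "F1": 0.8499, "AUC": 0.9347, "MCC": 0.7267}


def macro_average(rows):
    """Unweighted (macro) mean of each metric column across tasks."""
    columns = zip(*rows.values())
    return {name: mean(column) for name, column in zip(METRICS, columns)}


averages = macro_average(PER_TASK)
for name in METRICS:
    print(f"{name}: recomputed {averages[name]:.5f}, reported {REPORTED[name]:.4f}")
```

All four recomputed means agree with the reported averages to within 0.0001. This indicates that each task counts equally, regardless of its test-set size.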
Architecture
- Shared encoder: frozen ESM-2 (facebook/esm2_t33_650M_UR50D, 650M params) + learnable base embedding, mixed at esm_ratio=0.9
- Feature extraction (parallel): 4-layer Transformer + CNN (kernel=7, padding=3) → concatenated to 2560-dim features
- Heads: 22 binary classifiers (2560 → 256 → 128 → 2) with masked average pooling
- Loss: TIM (Threshold-Independent Multi-task) loss + label smoothing 0.1
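The architecture bullets above can be sketched in PyTorch as follows. Several details are assumptions that this card does not confirm:

- esm_ratio is treated as a fixed convex weight between the ESM-2 hidden states and the learnable base embedding.
- The Transformer is a stock `nn.TransformerEncoder` with 8 heads.
- The CNN is a single `Conv1d`.
- The LayerNorm is applied after concatenation.
- The heads use ReLU.

`SketchMTLExtractor` is a hypothetical name, and the class is not the repo's `MTLPeptideClassifier`. To stay self-contained, it takes precomputed ESM-2 hidden states as input instead of running the 650M model.

```python
import torch
import torch.nn as nn


class SketchMTLExtractor(nn.Module):
    """Hypothetical sketch of the described architecture; not the repo's MTLPeptideClassifier."""

    def __init__(self, task_names, vocab_size=33, hidden_dim=1280, esm_ratio=0.9,
                 num_layers=4, nhead=8, dropout=0.3):
        super().__init__()
        self.esm_ratio = esm_ratio
        # Learnable base embedding over ESM-2 token ids (ESM-2 uses id 1 for <pad>).
        self.base_embed = nn.Embedding(vocab_size, hidden_dim, padding_idx=1)
        layer = nn.TransformerEncoderLayer(hidden_dim, nhead, dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        # kernel_size=7 with padding=3 keeps the sequence length unchanged.
        self.cnn = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3)
        self.layer_norm = nn.LayerNorm(2 * hidden_dim)  # placement is an assumption
        # One MLP head per task: 2*hidden_dim -> 256 -> 128 -> 2 logits.
        self.heads = nn.ModuleDict({
            name: nn.Sequential(
                nn.Linear(2 * hidden_dim, 256), nn.ReLU(), nn.Dropout(dropout),
                nn.Linear(256, 128), nn.ReLU(), nn.Dropout(dropout),
                nn.Linear(128, 2),
            )
            for name in task_names
        })

    def forward(self, esm_hidden, input_ids, attention_mask, task_name):
        # Assumed mixing: convex combination of frozen ESM-2 states and the base embedding.
        x = self.esm_ratio * esm_hidden + (1 - self.esm_ratio) * self.base_embed(input_ids)
        # Transformer and CNN branches run in parallel on the same mixed embeddings.
        t = self.transformer(x, src_key_padding_mask=attention_mask == 0)
        c = self.cnn(x.transpose(1, 2)).transpose(1, 2)
        feats = self.layer_norm(torch.cat([t, c], dim=-1))  # (batch, length, 2*hidden_dim)
        # Masked average pooling: mean over real (non-padding) positions only.
        mask = attention_mask.unsqueeze(-1).to(feats.dtype)
        pooled = (feats * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.heads[task_name](pooled)  # (batch, 2) logits
```

With the default hidden_dim=1280 (the ESM-2 650M width), the concatenated features are 2 × 1280 = 2560 wide, matching the head input size listed above. A smaller hidden_dim is handy for quick shape checks.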
Tasks
| # | Task | Source |
|---|---|---|
| 1 | ACE_inhibitory | UniDL4BioPep |
| 2 | DPPIV_inhibitory | UniDL4BioPep |
| 3 | Bitter | UniDL4BioPep |
| 4 | Umami | UniDL4BioPep |
| 5 | Antimicrobial | UniDL4BioPep |
| 6 | Antimalarial (main) | UniDL4BioPep |
| 7 | Antimalarial_alt | UniDL4BioPep |
| 8 | Quorum_sensing | UniDL4BioPep |
| 9 | Anticancer (main) | UniDL4BioPep |
| 10 | Anticancer_alt | UniDL4BioPep |
| 11 | AntiMRSA | UniDL4BioPep |
| 12 | TTCA | UniDL4BioPep |
| 13 | BBP | UniDL4BioPep |
| 14 | Anti_parasitic | UniDL4BioPep |
| 15 | NeuroPred | UniDL4BioPep |
| 16 | Antibacterial | UniDL4BioPep |
| 17 | Antifungal | UniDL4BioPep |
| 18 | Antiviral | UniDL4BioPep |
| 19 | Toxicity | UniDL4BioPep |
| 20 | Anti_inflammatory | local dataset |
| 21 | Signal_peptide | local dataset |
| 22 | Antioxidant | UniDL4BioPep (antioxidant_FRS) |
Usage
import os

import torch
from huggingface_hub import hf_hub_download
from transformers import EsmTokenizer

from mtl_peptide_classifier import MTLPeptideClassifier, get_all_peptide_tasks

REPO = "tuankg1028/MTL-Peptide-Classifier"
checkpoint_dir = "MTL-Peptide-Classifier"
os.makedirs(checkpoint_dir, exist_ok=True)

# Download the trained weights and the architecture config from the Hub.
for fname in ["heads.pt", "shared_backbone.pt", "ablation_config.json"]:
    hf_hub_download(repo_id=REPO, filename=fname, local_dir=checkpoint_dir)

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
task_configs = get_all_peptide_tasks("datasets")  # needs a local datasets/ dir for the task names

# Rebuild the architecture with the same hyperparameters used in training.
model = MTLPeptideClassifier(
    task_configs=task_configs,
    hidden_dim=1280,
    esm_ratio=0.9,
    num_transformer_layers=4,
    dropout=0.3,
    use_transformer=True,
    use_cnn=True,
    unfreeze_esm=False,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
backbone = torch.load(f"{checkpoint_dir}/shared_backbone.pt", map_location=device)
heads = torch.load(f"{checkpoint_dir}/heads.pt", map_location=device)

# Restore the shared trainable modules (the frozen ESM-2 weights are not in this checkpoint).
model.base_embed.load_state_dict(backbone["base_embed"])
if "transformer" in backbone:
    model.transformer.load_state_dict(backbone["transformer"])
if "cnn" in backbone:
    model.cnn.load_state_dict(backbone["cnn"])
model.layer_norm.load_state_dict(backbone["layer_norm"])

# Restore each per-task head by name.
for name, head in model.heads.items():
    if name in heads:
        head.load_state_dict(heads[name])

model = model.to(device).eval()

# Space-separate the residues before tokenizing.
sequence = "MKWVTFISLLFLFSSAYSRGVFRR"
tokens = " ".join(list(sequence))
inputs = tokenizer(tokens, return_tensors="pt", max_length=128, padding="max_length", truncation=True)

with torch.no_grad():
    logits = model(inputs["input_ids"].to(device), inputs["attention_mask"].to(device), task_name="Antimicrobial")
probs = torch.softmax(logits, dim=-1)  # shape (1, 2); index 1 is presumably the positive class
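To score one sequence against every task head rather than only Antimicrobial, a helper along these lines should work. It assumes, as the loading loop above suggests, that `model.heads` is keyed by task name. It also assumes the model accepts `task_name=` as shown and that class index 1 is the positive class. `predict_all_tasks` is a hypothetical name, not part of `mtl_peptide_classifier.py`.

```python
import torch


def predict_all_tasks(model, tokenizer, sequence, device="cpu", max_length=128):
    """Return {task_name: positive-class probability} for one peptide sequence.

    Hypothetical helper: assumes model(input_ids, attention_mask, task_name=...) returns
    (batch, 2) logits and that index 1 is the positive class.
    """
    tokens = " ".join(sequence)  # space-separate residues, as in the usage example
    inputs = tokenizer(tokens, return_tensors="pt", max_length=max_length,
                       padding="max_length", truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    scores = {}
    with torch.no_grad():
        # Reuse the same tokenized input for every head; only the head changes.
        for task_name in model.heads.keys():
            logits = model(input_ids, attention_mask, task_name=task_name)
            scores[task_name] = torch.softmax(logits, dim=-1)[0, 1].item()
    return scores
```

With the model loaded above, `predict_all_tasks(model, tokenizer, sequence, device=device)` returns one probability per task. Any decision threshold, such as 0.5, is the caller's choice; this card does not specify one.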
Training
- Base model: facebook/esm2_t33_650M_UR50D (frozen)
- Batch size: 16, learning rate: 1e-4, 50 epochs, dropout: 0.3
- 3-way split per task: the training data is split 80% train / 20% val (used for checkpoint selection), and a separate held-out test CSV is evaluated once
- Mixed precision, gradient clipping at 1.0, cosine LR schedule with 5 warmup epochs
- TIM loss + label smoothing 0.1
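The schedule and update step listed above can be sketched in PyTorch as follows. The card does not state several choices, so they are assumptions here:

- AdamW as the optimizer.
- Linear warmup.
- Per-epoch scheduler stepping.
- `nn.CrossEntropyLoss(label_smoothing=0.1)` standing in for the loss. The TIM loss itself is not reproduced.

`make_training_objects` and `train_step` are hypothetical names.

```python
import math

import torch
import torch.nn as nn

EPOCHS, WARMUP_EPOCHS, LR, CLIP_NORM = 50, 5, 1e-4, 1.0


def lr_lambda(epoch):
    """LR multiplier: linear warmup over WARMUP_EPOCHS, then cosine decay toward 0."""
    if epoch < WARMUP_EPOCHS:
        return (epoch + 1) / WARMUP_EPOCHS
    progress = (epoch - WARMUP_EPOCHS) / max(1, EPOCHS - WARMUP_EPOCHS)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


def make_training_objects(model, device_type):
    # Only trainable parameters are optimized, so a frozen ESM-2 is skipped.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=LR)  # optimizer choice is an assumption
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)  # call .step() once per epoch
    # Loss scaling is only needed for float16 autocast on CUDA; disabled elsewhere.
    scaler = torch.cuda.amp.GradScaler(enabled=(device_type == "cuda"))
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # stand-in for TIM loss + smoothing
    return optimizer, scheduler, scaler, criterion


def train_step(model, inputs, labels, optimizer, scaler, criterion, device_type):
    optimizer.zero_grad(set_to_none=True)
    # Mixed precision forward pass (active on CUDA only in this sketch).
    with torch.autocast(device_type=device_type, enabled=(device_type == "cuda")):
        loss = criterion(model(inputs), labels)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale first so clipping sees the true gradient norm
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=CLIP_NORM)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```

With 5 warmup epochs, the learning rate ramps from 0.2 × 1e-4 in epoch 0 to the full 1e-4 at epoch 5. It then follows a half-cosine down over the remaining 45 epochs.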
Files
- heads.pt: per-task classifier heads
- shared_backbone.pt: base embedding, Transformer, CNN, LayerNorm
- ablation_config.json: architecture configuration for reproducibility
- test_results.json: held-out test metrics (per task + averages)
- mtl_peptide_classifier.py: model code
Requirements
torch>=2.0.0
transformers>=4.30.0
huggingface_hub
numpy
pandas
scikit-learn