---
license: mit
base_model: facebook/esm2_t33_650M_UR50D
tags:
- biology
- peptide
- multi-task-learning
- protein
- classification
---

# MTL Peptide Classifier (22 Tasks)

Multi-task learning (MTL) peptide classifier covering 22 binary peptide-activity tasks. It is built on a frozen ESM-2 (650M) backbone with a parallel Transformer + CNN feature extractor and per-task classification heads, following a design inspired by PDeepPP.

## Architecture

- **Shared encoder**: frozen ESM-2 (`facebook/esm2_t33_650M_UR50D`, 650M params) + learnable base embedding, mixed at `esm_ratio=0.9`
- **Feature extraction (parallel)**: 4-layer Transformer + CNN (kernel=7, padding=3) → concatenated to 2560-dim features
- **Heads**: 22 binary classifiers (`2560 → 256 → 128 → 2`) with masked average pooling
- **Loss**: TIM (Threshold-Independent Multi-task) loss + label smoothing 0.1

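The data flow described above can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the code in `mtl_peptide_classifier.py`: the ESM/base mix is assumed to be a weighted sum controlled by `esm_ratio`, the Transformer branch uses `nn.TransformerEncoder` with a hypothetical 8 attention heads, the CNN branch is a single `Conv1d`, and the ReLU/dropout placement in the head is a guess. The names `mix_embeddings`, `ParallelFeatureExtractor`, `masked_mean_pool` and `TaskHead` are illustrative only.

```python
import torch
import torch.nn as nn


def mix_embeddings(esm_emb: torch.Tensor, base_emb: torch.Tensor, esm_ratio: float = 0.9) -> torch.Tensor:
    # Assumption: a weighted sum of frozen ESM-2 features and the learnable base embedding
    return esm_ratio * esm_emb + (1.0 - esm_ratio) * base_emb


class ParallelFeatureExtractor(nn.Module):
    """Transformer and CNN branches see the same input; their outputs are concatenated."""

    def __init__(self, dim: int = 1280, num_layers: int = 4, nhead: int = 8, dropout: float = 0.3):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=nhead, dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        # kernel_size=7 with padding=3 preserves the sequence length
        self.cnn = nn.Conv1d(dim, dim, kernel_size=7, padding=3)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim); mask: (batch, seq_len), 1 = real token, 0 = padding
        t = self.transformer(x, src_key_padding_mask=~mask.bool())
        c = self.cnn(x.transpose(1, 2)).transpose(1, 2)  # Conv1d expects (batch, dim, seq_len)
        return torch.cat([t, c], dim=-1)  # (batch, seq_len, 2 * dim), i.e. 2560 for dim=1280


def masked_mean_pool(feats: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Average only over real tokens so padding does not dilute the sequence embedding
    m = mask.unsqueeze(-1).to(feats.dtype)
    return (feats * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)


class TaskHead(nn.Module):
    """Binary head 2560 -> 256 -> 128 -> 2 (activation and dropout placement are assumptions)."""

    def __init__(self, in_dim: int = 2560, dropout: float = 0.3):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, 256), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(128, 2),
        )

    def forward(self, feats: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return self.mlp(masked_mean_pool(feats, mask))
```

Under these assumptions, a `dim=1280` input yields the 2560-dim per-token features that each of the 22 heads pools and classifies; only the heads are task-specific.
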
## Tasks

| # | Task | Source |
|---|---|---|
| 1 | ACE_inhibitory | UniDL4BioPep |
| 2 | DPPIV_inhibitory | UniDL4BioPep |
| 3 | Bitter | UniDL4BioPep |
| 4 | Umami | UniDL4BioPep |
| 5 | Antimicrobial | UniDL4BioPep |
| 6 | Antimalarial (main) | UniDL4BioPep |
| 7 | Antimalarial_alt | UniDL4BioPep |
| 8 | Quorum_sensing | UniDL4BioPep |
| 9 | Anticancer (main) | UniDL4BioPep |
| 10 | Anticancer_alt | UniDL4BioPep |
| 11 | AntiMRSA | UniDL4BioPep |
| 12 | TTCA | UniDL4BioPep |
| 13 | BBP | UniDL4BioPep |
| 14 | Anti_parasitic | UniDL4BioPep |
| 15 | NeuroPred | UniDL4BioPep |
| 16 | Antibacterial | UniDL4BioPep |
| 17 | Antifungal | UniDL4BioPep |
| 18 | Antiviral | UniDL4BioPep |
| 19 | Toxicity | UniDL4BioPep |
| 20 | Anti_inflammatory | local dataset |
| 21 | Signal_peptide | local dataset |
| 22 | Antioxidant | UniDL4BioPep (antioxidant_FRS) |

## Usage

```python
import os
import sys

import torch
from huggingface_hub import hf_hub_download
from transformers import EsmTokenizer

REPO = "tuankg1028/MTL-Peptide-Classifier"
checkpoint_dir = "MTL-Peptide-Classifier"
os.makedirs(checkpoint_dir, exist_ok=True)

# Download the weights, the architecture config, and the model code
for fname in ["heads.pt", "shared_backbone.pt", "ablation_config.json", "mtl_peptide_classifier.py"]:
    hf_hub_download(repo_id=REPO, filename=fname, local_dir=checkpoint_dir)

# Make the downloaded mtl_peptide_classifier.py importable
sys.path.insert(0, checkpoint_dir)
from mtl_peptide_classifier import MTLPeptideClassifier, get_all_peptide_tasks  # noqa: E402

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
task_configs = get_all_peptide_tasks("datasets")  # needs a local datasets/ dir for task names

# Architecture settings; keep these consistent with ablation_config.json
model = MTLPeptideClassifier(
    task_configs=task_configs,
    hidden_dim=1280,
    esm_ratio=0.9,
    num_transformer_layers=4,
    dropout=0.3,
    use_transformer=True,
    use_cnn=True,
    unfreeze_esm=False,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
backbone = torch.load(f"{checkpoint_dir}/shared_backbone.pt", map_location=device)
heads = torch.load(f"{checkpoint_dir}/heads.pt", map_location=device)

# shared_backbone.pt holds the base embedding, Transformer, CNN and LayerNorm (not ESM-2)
model.base_embed.load_state_dict(backbone["base_embed"])
if "transformer" in backbone:
    model.transformer.load_state_dict(backbone["transformer"])
if "cnn" in backbone:
    model.cnn.load_state_dict(backbone["cnn"])
model.layer_norm.load_state_dict(backbone["layer_norm"])

# Restore every task head that has saved weights
for name, head in model.heads.items():
    if name in heads:
        head.load_state_dict(heads[name])

model = model.to(device).eval()

# Space-separate the residues before tokenizing
sequence = "MKWVTFISLLFLFSSAYSRGVFRR"
tokens = " ".join(sequence)
inputs = tokenizer(tokens, return_tensors="pt", max_length=128, padding="max_length", truncation=True)
with torch.no_grad():
    logits = model(
        inputs["input_ids"].to(device),
        inputs["attention_mask"].to(device),
        task_name="Antimicrobial",
    )
    probs = torch.softmax(logits, dim=-1)  # class probabilities from the binary head
```
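
Because all 22 heads share one backbone, a single peptide can be scored for every task by looping over the heads. The helper below is a hypothetical convenience function, not part of `mtl_peptide_classifier.py`. It assumes only what the usage example already relies on: `model.heads` is keyed by task name, and `model(input_ids, attention_mask, task_name=...)` returns two-class logits. Reading column 1 as the positive class is an additional assumption.

```python
from typing import Dict

import torch


@torch.no_grad()
def predict_all_tasks(model, tokenizer, sequence: str, device: str = "cpu", max_length: int = 128) -> Dict[str, float]:
    """Hypothetical helper: score one peptide against every task head.

    Assumes model.heads is keyed by task name and that
    model(input_ids, attention_mask, task_name=...) returns (batch, 2) logits.
    """
    inputs = tokenizer(
        " ".join(sequence),  # space-separated residues, as in the usage example
        return_tensors="pt",
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    scores = {}
    for task_name in model.heads.keys():
        logits = model(input_ids, attention_mask, task_name=task_name)
        probs = torch.softmax(logits, dim=-1)
        scores[task_name] = probs[0, 1].item()  # assumes index 1 = positive class
    return scores
```

With the objects built in the usage example, `predict_all_tasks(model, tokenizer, "MKWVTFISLLFLFSSAYSRGVFRR", device)` would return a `{task_name: probability}` dictionary. The backbone is re-run once per task here for simplicity; caching the shared features would be faster but requires access to model internals not documented in this card.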

## Training

- Base model: `facebook/esm2_t33_650M_UR50D` (frozen)
- Batch size: 16, learning rate: 1e-4, 50 epochs, dropout: 0.3
- 3-way split per task: the training data is divided 80% train / 20% validation (used for checkpoint selection), and a separate held-out test CSV is evaluated once
- Mixed-precision training, gradient clipping at 1.0, cosine learning-rate schedule with 5 warmup epochs
- TIM loss + label smoothing 0.1

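The optimization settings listed above can be sketched in PyTorch as follows. Assumptions not stated in this card: the optimizer is `AdamW`, warmup is linear and stepped once per epoch, and clipping is by global gradient norm. The TIM loss is not reproduced; `compute_loss` is a caller-supplied closure. `warmup_cosine_factor`, `build_optimizer` and `train_step` are hypothetical names.

```python
import math

import torch


def warmup_cosine_factor(epoch: int, warmup_epochs: int = 5, total_epochs: int = 50) -> float:
    """LR multiplier: linear warmup for `warmup_epochs`, then cosine decay toward 0."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


def build_optimizer(model: torch.nn.Module, lr: float = 1e-4):
    # Only trainable parameters (frozen ESM-2 weights have requires_grad=False)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=lr)  # assumption: the card does not name the optimizer
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_factor)
    return optimizer, scheduler


def train_step(model, compute_loss, optimizer, scaler=None, max_grad_norm: float = 1.0) -> float:
    """One optimizer step; pass a CUDA GradScaler as `scaler` to enable mixed precision on GPU."""
    optimizer.zero_grad(set_to_none=True)
    use_amp = scaler is not None
    with torch.autocast(device_type="cuda" if use_amp else "cpu", enabled=use_amp):
        loss = compute_loss()
    if use_amp:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip the real gradients, not the scaled ones
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
    return loss.item()
```

Calling `scheduler.step()` once per epoch makes the multiplier ramp from 0.2 to 1.0 over the first 5 epochs and then decay along a cosine curve over the remaining 45. Starting the warmup at `1/warmup_epochs` rather than 0 avoids a wasted first epoch with a zero learning rate.
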
## Files

- `heads.pt`: per-task classifier heads
- `shared_backbone.pt`: base embedding, Transformer, CNN, LayerNorm
- `ablation_config.json`: architecture configuration for reproducibility
- `test_results.json`: held-out test metrics (per task + averages)
- `mtl_peptide_classifier.py`: model code

## Requirements

```
torch>=2.0.0
transformers>=4.30.0
huggingface_hub
numpy
pandas
scikit-learn
```
|
|