---
license: mit
base_model:
- esmc_600m
tags:
- protein
- antibody
- esmc
- biology
- CDR
---

# AbCDR-ESMC: Antibody ESMC Paired Model

## Model Description

This model is a fine-tuned version of ESMC-600M (ESM Cambrian) for paired antibody sequences (heavy and light chains).

**Key Features:**

- Trained on paired antibody sequences
- 50% CDR fine-tuning
- Input format: heavy and light chains separated by "-"
- Output: 1152-dimensional embeddings
- Optimized for antibody CDR region understanding

### Preprocessing

Sequences were:

1. Combined as `HEAVY-LIGHT` (with a "-" separator)
2. Cleaned by replacing non-standard amino acids with `X`
3. Tokenized with the ESMC tokenizer
4. Annotated with CDR regions for masking
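The preprocessing steps above can be sketched as follows. This is an illustrative reconstruction, not the released training code; the exact cleanup rules used during training may differ.

```python
import re

STANDARD_AA = "ACDEFGHIKLMNPQRSTVWY"
SEP_TOKEN = "-"

def preprocess_pair(heavy: str, light: str) -> str:
    """Combine a heavy/light chain pair into the model's input format.

    Any character outside the 20 standard single-letter amino acid
    codes is replaced with "X" before the chains are joined.
    """
    def clean(s: str) -> str:
        return re.sub(f"[^{STANDARD_AA}]", "X", s.upper())

    return f"{clean(heavy)}{SEP_TOKEN}{clean(light)}"

print(preprocess_pair("EVQLVBSGG", "DIQMTQSPS"))  # "B" is non-standard -> "EVQLVXSGG-DIQMTQSPS"
```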

## Installation & Requirements

```bash
pip install torch
pip install safetensors
pip install huggingface_hub
pip install esm==3.1.4
```

## Usage

### Loading the Model

```python
import os

import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open

from esm.models.esmc import ESMC
from esm.tokenization import get_esmc_model_tokenizers

# Configuration
REPO_ID = "MahTala/AbCDR-ESMC"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and base model
tokenizer = get_esmc_model_tokenizers()
model = ESMC.from_pretrained("esmc_600m").to(device)

# Download fine-tuned weights
local_ckpt_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="model.safetensors",
    token=os.getenv("HF_TOKEN"),  # only needed for private repos
)

# Load the checkpoint tensors
original_state_dict = {}
with safe_open(local_ckpt_path, framework="pt") as sf:
    for key in sf.keys():
        original_state_dict[key] = sf.get_tensor(key)

# Strip the "esmC_model." prefix added during fine-tuning
renamed_state_dict = {
    key.removeprefix("esmC_model."): value
    for key, value in original_state_dict.items()
}

# Load weights (strict=False tolerates checkpoint keys absent from the base model)
model.load_state_dict(renamed_state_dict, strict=False)
model.eval()
```

### Extract Embeddings - Method 1 (High-Level API)

```python
from esm.sdk.api import ESMProtein, LogitsConfig

SEP_TOKEN = "-"

# Example sequences
heavy_chain = (
    "EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVAVISYDGSNKYYADSVKGRF"
    "TISADTSKNTAYLQMNSLRAEDTAVYYCAREGYYGSSYWYFDYWGQGTLVTVSS"
)
light_chain = (
    "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGS"
    "GTDFTLTISSLQPEDFATYYCQQSYSTPLTFGGGTKVEIK"
)

# Combine with separator
paired_sequence = f"{heavy_chain}{SEP_TOKEN}{light_chain}"

# Create protein object and encode
protein = ESMProtein(sequence=paired_sequence)
protein_tensor = model.encode(protein)

# Get embeddings
logits_output = model.logits(
    protein_tensor,
    LogitsConfig(sequence=True, return_embeddings=True),
)

embeddings = logits_output.embeddings    # Shape: (1, seq_len, 1152)
logits = logits_output.logits.sequence   # Shape: (1, seq_len, 64)

print(f"Embeddings shape: {embeddings.shape}")  # (1, L, 1152)
print(f"Embeddings dtype: {embeddings.dtype}")  # float32
```

### Extract Embeddings - Method 2 (Low-Level Direct)

```python
# Tokenize sequence
seq_encoded = tokenizer(paired_sequence, return_tensors="pt")
seq_input_ids = seq_encoded["input_ids"].to(device)

# Forward pass
with torch.no_grad():
    outputs = model(sequence_tokens=seq_input_ids)

embeddings_direct = outputs.embeddings    # Shape: (1, seq_len, 1152)
logits_direct = outputs.sequence_logits   # Shape: (1, seq_len, 64)

print(f"Embeddings shape: {embeddings_direct.shape}")  # (1, L, 1152)
print(f"Embeddings dtype: {embeddings_direct.dtype}")  # bfloat16
```

### Mean Pooling for Fixed-Size Representation

```python
# Mean pooling over sequence length
sequence_representation = embeddings_direct.mean(dim=1)  # (1, 1152)
print(f"Pooled embedding shape: {sequence_representation.shape}")

# Get interface embedding (at separator position);
# +1 accounts for the BOS token prepended by the tokenizer
separator_pos = len(heavy_chain) + 1
interface_embedding = embeddings_direct[0, separator_pos, :]  # (1152,)
```
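Beyond whole-pair pooling, the separator position can also be used to pool the heavy and light chains separately. A minimal sketch, with a random tensor standing in for the model's embedding output and assuming one BOS token before the heavy chain and one EOS token after the light chain:

```python
import torch

heavy_len, light_len, dim = 122, 107, 1152
seq_len = 1 + heavy_len + 1 + light_len + 1  # BOS + heavy + "-" + light + EOS
embeddings = torch.randn(1, seq_len, dim)    # stand-in for embeddings_direct

sep_pos = 1 + heavy_len                      # position of the "-" token
heavy_repr = embeddings[0, 1:sep_pos].mean(dim=0)                  # (1152,)
light_repr = embeddings[0, sep_pos + 1 : seq_len - 1].mean(dim=0)  # (1152,)

print(heavy_repr.shape, light_repr.shape)
```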

### Batch Processing

```python
# Multiple sequences
sequences = [
    f"{heavy_chain}{SEP_TOKEN}{light_chain}",
    f"{heavy_chain[:100]}{SEP_TOKEN}{light_chain[:100]}",
]

# Tokenize with padding
batch_encoded = tokenizer(sequences, return_tensors="pt", padding=True)
batch_input_ids = batch_encoded["input_ids"].to(device)

# Forward pass
with torch.no_grad():
    batch_outputs = model(sequence_tokens=batch_input_ids)

batch_embeddings = batch_outputs.embeddings  # (batch_size, max_seq_len, 1152)
print(f"Batch embeddings shape: {batch_embeddings.shape}")
```
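When mean-pooling a padded batch, averaging over all positions lets padding embeddings dilute the representation of shorter sequences. A masked-mean sketch using the tokenizer's attention mask (dummy tensors stand in for the model output):

```python
import torch

def masked_mean(embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool (B, L, D) embeddings, ignoring padded positions.

    attention_mask is (B, L) with 1 for real tokens and 0 for padding,
    as returned by the tokenizer under the "attention_mask" key.
    """
    mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)  # (B, L, 1)
    summed = (embeddings * mask).sum(dim=1)                   # (B, D)
    counts = mask.sum(dim=1).clamp(min=1)                     # (B, 1)
    return summed / counts

# Dummy batch: second sequence is padded after position 3
emb = torch.randn(2, 5, 1152)
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
pooled = masked_mean(emb, mask)
print(pooled.shape)  # torch.Size([2, 1152])
```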

## Input Format

**Required Format:** `HEAVY_CHAIN-LIGHT_CHAIN`

- Heavy and light chains must be separated by a hyphen (`-`)
- Use standard single-letter amino acid codes
- No spaces in the sequence

**Example:**

```python
sequence = "EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMS...-DIQMTQSPSSLSASVGDRVTITCRASQSISS..."
```
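A small helper (illustrative, not part of the released package) can catch malformed inputs before tokenization. It enforces the rules above, allowing `X` for non-standard residues:

```python
STANDARD_AA = set("ACDEFGHIKLMNPQRSTVWYX")

def validate_paired_sequence(sequence: str) -> None:
    """Raise ValueError if the input is not HEAVY-LIGHT with valid residues."""
    parts = sequence.split("-")
    if len(parts) != 2:
        raise ValueError("Expected exactly one '-' separating heavy and light chains")
    for name, chain in zip(("heavy", "light"), parts):
        if not chain:
            raise ValueError(f"Empty {name} chain")
        bad = set(chain) - STANDARD_AA
        if bad:
            raise ValueError(f"Invalid residues in {name} chain: {sorted(bad)}")

validate_paired_sequence("EVQLVESGG-DIQMTQSPS")  # OK: no exception raised
```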

## Output

### Embeddings

- **Dimension:** 1152 (ESMC hidden size)
- **Sequence length:** Variable (up to the model's maximum length)
- **Format:** PyTorch tensor
- **Dtype:**
  - High-level API: float32
  - Low-level API: bfloat16

### Logits

- **Dimension:** 64 (ESMC vocabulary size)
- **Format:** PyTorch tensor
- **Dtype:** bfloat16
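Note that NumPy has no bfloat16 dtype, so low-level outputs should be upcast to float32 before calling `.numpy()`. A quick sketch with a dummy tensor standing in for a low-level embedding:

```python
import torch

# Stand-in for a low-level bfloat16 embedding tensor
emb_bf16 = torch.randn(1, 10, 1152, dtype=torch.bfloat16)

# Upcast before moving to NumPy for downstream analysis
emb_np = emb_bf16.to(torch.float32).cpu().numpy()
print(emb_np.dtype)  # float32
```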

## Citation

```bibtex
@article{Talaei2025.10.31.685149,
  author = {Talaei, Mahtab and Walker, Kenji C. and Hao, Boran and Jolley, Eliot and Jin, Yeping and Kozakov, Dima and Misasi, John and Vajda, Sandor and Paschalidis, Ioannis Ch. and Joseph-McCarthy, Diane},
  title = {CDR-aware masked language models for paired antibodies enable state-of-the-art binding prediction},
  year = {2025},
  doi = {10.1101/2025.10.31.685149},
  eprint = {https://www.biorxiv.org/content/early/2025/10/31/2025.10.31.685149.full.pdf},
  journal = {bioRxiv}
}

@article{hayes2024simulating,
  title = {Simulating 500 million years of evolution with a language model},
  author = {Hayes, Thomas and Rao, Roshan and Akin, Halil and Sofroniew, Nicholas J and Oktay, Deniz and Lin, Zeming and Verkuil, Robert and Tran, Vincent Q and Deaton, Jonathan and Wiggert, Marius and others},
  journal = {bioRxiv},
  year = {2024}
}
```

## Model Card Authors

Mahtab Talaei

## Contact

- **Maintainer:** Network Optimization & Control (NOC) Lab
- **Email:** mtalaei@bu.edu
- **GitHub:** [https://github.com/Mah-Tala/AbCDR-ESM](https://github.com/Mah-Tala/AbCDR-ESM)
- **Paper:** [bioRxiv preprint](https://www.biorxiv.org/content/10.1101/2025.10.31.685149v1)

## License

This model is released under the MIT License.

## Acknowledgments

- Base model: ESMC (ESM Cambrian) by EvolutionaryScale
- Data: OAS database

---

**Note:** For private repositories, you'll need to authenticate:

```bash
# Option 1: CLI login
huggingface-cli login

# Option 2: Environment variable
export HF_TOKEN="your_token_here"
```