A post-translational modification (PTM) prediction model obtained by fine-tuning the [ESM2-8M model](https://www.science.org/doi/full/10.1126/science.ade2574).
The dataset comes from the [PTM-Mamba project](https://github.com/programmablebio/ptm-mamba?tab=readme-ov-file).
Modifications of the same type that occur on different amino acids are merged into one class, and all rare modifications are grouped together into a single class.
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
# Hugging Face Hub identifier of the fine-tuned checkpoint.
model_id = "leexiaohua/PTM_small"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the token-classification model, place it on the selected device and
# put it in evaluation mode in one chained expression.
model = AutoModelForTokenClassification.from_pretrained(model_id).to(device).eval()
def probe_ptm(sequence, threshold=0.5):
    """Print the residues of *sequence* predicted to carry a modification.

    The sequence is tokenized with the module-level ``tokenizer``, scored by
    the module-level ``model`` on ``device``, and the per-token logits are
    turned into class probabilities with a softmax.  For every residue the
    most probable class other than class 0 ("no modification") is taken, and
    the residue is reported when that probability is strictly greater than
    *threshold*.

    Args:
        sequence: Protein sequence as a string of one-letter residue codes.
        threshold: Minimum probability (exclusive) for a site to be reported.

    Returns:
        None. Results are written to standard output.

    Warns:
        UserWarning: If the tokens stop lining up with *sequence* (reported
            positions may then be shifted), or if only a prefix of the
            sequence was scanned because the input was truncated to
            ``max_length=1024`` tokens.
    """
    import warnings

    # 1. Tokenize and move to device
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # 2. Model inference
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)

    # Best modification class per token, computed once for the whole
    # sequence instead of one torch.max/.item() round trip per residue.
    # Column 0 is the "no modification" class and is excluded.
    best_probs, best_offsets = probs[0, :, 1:].max(dim=-1)
    is_hit = (best_probs > threshold).tolist()
    best_probs = best_probs.tolist()
    best_labels = (best_offsets + 1).tolist()  # undo the shift from dropping column 0

    # 3. Class mapping
    ptm_map = {
        1: "Phospho",
        2: "Acetyl",
        3: "Methyl",
        4: "Glycosyl",
        5: "Lipid",
        6: "Rare",
    }

    print(f"--- Scanning sequence with {threshold:.0%} sensitivity (Bidirectional Alignment) ---")

    # Only these tokens are framing added by the tokenizer.  Other "<...>"
    # tokens (e.g. unknown or mask) stand for something in the input and must
    # still advance the sequence index, otherwise every later position is
    # reported shifted.
    boundary_tokens = {tokenizer.cls_token, tokenizer.eos_token, tokenizer.pad_token}

    seq_idx = 0
    found = False
    first_mismatch = None
    for token, prob, label, hit in zip(tokens, best_probs, best_labels, is_hit):
        if token in boundary_tokens:
            continue

        amino_acid = sequence[seq_idx] if seq_idx < len(sequence) else None
        is_placeholder = token.startswith("<") or token.endswith(">")

        if not is_placeholder:
            # Verify token alignment with the original sequence; remember the
            # first place it breaks so the caller can be told once.
            if amino_acid is None or token.upper() != amino_acid.upper():
                if first_mismatch is None:
                    first_mismatch = seq_idx + 1

            if hit:
                shown = amino_acid if amino_acid is not None else token
                # Fall back to a generic name if the checkpoint exposes a
                # class index that is not in the table above.
                ptm_name = ptm_map.get(label, f"Class {label}")
                print(f"Position {seq_idx + 1}: {shown} -> Likely {ptm_name} (Confidence: {prob:.2%})")
                found = True

        seq_idx += 1

    if not found:
        print("No PTM sites found above the specified threshold.")

    if first_mismatch is not None:
        warnings.warn(
            f"Tokens no longer match the input sequence from position {first_mismatch}; "
            "positions reported from there on may be shifted.",
            stacklevel=2,
        )
    elif seq_idx < len(sequence):
        warnings.warn(
            f"Only the first {seq_idx} of {len(sequence)} residues were scanned "
            "(input truncated to max_length=1024 tokens).",
            stacklevel=2,
        )
# Example protein sequence (one-letter residue codes) used to demonstrate the scanner.
test_seq = "MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMNVQNGKWDSDPSGTKTCIDTKEGILQYCQEVYPELQITNVVEANQPVTIQNWCKRGRKQCKTHPHFVIPYRCLVGEFVSDALLVPDKCKFLHQERMDVCETHLHWHTVAKETCSEKSTNLHDYGMLLPCGIDKFRGVEFVCCPLAEESDNVDSADAEEDDSDVWWGGADTDYADGSEDKVVEVAEEEEVAEVEEEEADDDEDDEDGDEVEEEAEEPYEEATERTTSIATTTTTTTESVEEVVREVCSEQAETGPCRAMISRWYFDVTEGKCAPFFYGGCGGNRNNFDTEEYCMAVCGSAMSQSLLKTTQEPLARDPVKLPTTAASTPDAVDKYLETPGDENEHAHFQKAKERLEAKHRERMSQVMREWEEAERQAKNLPKADKKAVIQHFQEKVESLEQEAANERQQLVETHMARVEAMLNDRRRLALENYITALQAVPPRPRHVFNMLKKYVRAEQKDRQHTLKHFEHVRMVDPKKAAQIRSQVMTHLRVIYERMNQSLSLLYNVPAVAEEIQDEVDELLQKEQNYSDDVLANMISEPRISYGNDALMPSLTETKTTVELLPVNGEFSLDDLQPWHSFGADSVPANTENEVEPVDARPAADRGLTTRPGSGLTNIKTEEISEVKMDAEFRHDSGYEVHHQKLVFFAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHHGVVEVDAAVTPEERHLSKMQQNGYENPTYKFFEQMQN"
# Scan the example with a 0.2 threshold, lower than the function's default of 0.5.
probe_ptm(test_seq, threshold=0.2)
The output will be similar to:
--- Scanning sequence with 20% sensitivity (Bidirectional Alignment) ---
Position 198: S -> Likely Phospho (Confidence: 74.38%)
Position 206: S -> Likely Phospho (Confidence: 79.13%)
Position 217: Y -> Likely Rare (Confidence: 90.77%)
Position 262: Y -> Likely Rare (Confidence: 59.31%)
Position 336: Y -> Likely Rare (Confidence: 67.14%)
Position 441: S -> Likely Phospho (Confidence: 79.16%)
Position 497: Y -> Likely Phospho (Confidence: 44.33%)
Position 542: N -> Likely Glycosyl (Confidence: 95.32%)
Position 571: N -> Likely Glycosyl (Confidence: 88.64%)
Position 729: T -> Likely Phospho (Confidence: 79.49%)
Position 730: S -> Likely Phospho (Confidence: 95.94%)
Position 757: Y -> Likely Phospho (Confidence: 48.01%)
- Downloads last month
- 82
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support
