A protein interaction-interface prediction model obtained by fine-tuning the [ESM2-8M model](https://www.science.org/doi/full/10.1126/science.ade2574). Model deployment follows Synthyra's [FastESM](https://huggingface.co/Synthyra) series.
The dataset comes from the [MaSIF project](https://github.com/LPDI-EPFL/masif).
This model was trained solely for auxiliary purposes and is not recommended for direct use. For predicting interaction interfaces, the 150M and 650M ESM2 models perform significantly better than the 8M model.
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Quick start: load the fine-tuned interface predictor and its tokenizer once.
# The model goes to the GPU when CUDA is available (CPU otherwise) and is put
# into eval mode for inference.
model_id = "leexiaohua/iface_small"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForTokenClassification.from_pretrained(
    model_id,
    trust_remote_code=True,  # permits running modeling code hosted in the model repo
).to(device).eval()
import torch
import numpy as np
import trimesh
import pandas as pd
from transformers import AutoTokenizer, AutoModelForTokenClassification
from scipy.spatial import KDTree
from sklearn.metrics import roc_auc_score
# Three-letter -> one-letter amino-acid codes; any other residue name maps to 'X'.
_THREE_TO_ONE = {
    'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
    'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N',
    'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W',
    'ALA': 'A', 'VAL': 'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M',
}


def _parse_pdb_ca(pdb_path, chain_id=None):
    """Read CA coordinates and the one-letter sequence from a PDB file.

    Only ``ATOM`` records whose atom-name field (columns 13-16) is ``CA`` are
    used. Each residue (chain ID + residue number + insertion code) is kept
    once, so alternate-location duplicates do not inflate the sequence.
    Parsing stops at the first ``ENDMDL`` record, so multi-model files
    contribute only their first model.

    Args:
        pdb_path (str): Path to the .pdb file.
        chain_id (str | None): If given, keep only residues of this chain.

    Returns:
        tuple[np.ndarray, str]: ``(coords, sequence)``; ``coords`` has shape
        ``(n_residues, 3)`` and ``sequence`` has length ``n_residues``.
    """
    coords = []
    letters = []
    seen = set()
    with open(pdb_path, "r") as handle:
        for line in handle:
            if line.startswith("ENDMDL"):
                break
            if not line.startswith("ATOM") or line[12:16].strip() != "CA":
                continue
            if chain_id is not None and line[21] != chain_id:
                continue
            # Chain ID (col 22) + residue number and insertion code (cols 23-27).
            residue_key = (line[21], line[22:27])
            if residue_key in seen:
                continue  # alternate location of a residue already recorded
            seen.add(residue_key)
            coords.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
            letters.append(_THREE_TO_ONE.get(line[17:20].strip(), "X"))
    return np.array(coords, dtype=float).reshape(-1, 3), "".join(letters)


def _residue_interface_labels(res_coords, ply_coords, iface_values, threshold):
    """Derive binary per-residue interface labels from per-vertex surface values.

    Each surface vertex is assigned to its nearest CA atom. A residue is
    labelled 1 when the maximum ``iface`` value among its assigned vertices
    exceeds ``threshold``. Otherwise it is labelled 0, including residues
    that receive no vertex at all.
    """
    best = np.full(len(res_coords), -np.inf)
    _, nearest = KDTree(res_coords).query(ply_coords)
    # Unbuffered per-index max: a residue hit by several vertices keeps the largest value.
    np.maximum.at(best, nearest, iface_values)
    return (best > threshold).astype(int)


def evaluate_protein_interface(pdb_path, ply_path, model_id="leexiaohua/iface_small",
                               *, chain_id=None, model=None, tokenizer=None,
                               iface_threshold=0.5):
    """
    Evaluates the model's interface prediction performance against ground truth from a PLY file.

    Args:
        pdb_path (str): Path to the .pdb file.
        ply_path (str): Path to the .ply file containing 'iface' metadata.
        model_id (str): Hugging Face model ID. It is used to load whichever
            of ``model`` / ``tokenizer`` is not supplied.
        chain_id (str | None): Restrict the PDB parse to this chain (e.g. the
            chain the PLY surface was computed for). ``None`` keeps every
            chain, as before.
        model: Optional pre-loaded token-classification model, used as-is.
            It is neither moved nor switched to eval mode. Inputs are sent to
            the device of its parameters. Pass one to avoid reloading the
            model for every protein.
        tokenizer: Optional pre-loaded tokenizer matching ``model``.
        iface_threshold (float): A residue counts as interface when the max
            ``iface`` value of the surface vertices nearest to it exceeds
            this value.

    Returns:
        dict: ``seq_len``, ``num_interface_res`` and ``roc_auc``. ``roc_auc``
        is ``None`` when the score/label lengths differ or the labels contain
        a single class.

    Raises:
        ValueError: If no (matching) CA atoms are found, or the PLY vertex
            count disagrees with the length of its ``iface`` data.
    """
    # 1. Load model and tokenizer only when the caller did not supply them.
    if model is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForTokenClassification.from_pretrained(model_id, trust_remote_code=True)
        model.to(device)
        model.eval()
    else:
        device = next(model.parameters()).device
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_id)

    # 2. Surface mesh. process=False stops trimesh from merging/removing vertices,
    #    so mesh.vertices stays aligned row-for-row with the raw 'iface' column.
    mesh = trimesh.load(ply_path, process=False)
    ply_coords = np.asarray(mesh.vertices, dtype=float)
    iface_values = np.asarray(
        mesh.metadata['_ply_raw']['vertex']['data']['iface'], dtype=float
    ).flatten()
    if len(iface_values) != len(ply_coords):
        raise ValueError(
            f"PLY vertex count ({len(ply_coords)}) does not match "
            f"'iface' length ({len(iface_values)})"
        )

    # 3. Structure: CA coordinates and sequence.
    res_coords, sequence = _parse_pdb_ca(pdb_path, chain_id)
    if not sequence:
        raise ValueError(f"No CA atoms found in {pdb_path!r} (chain_id={chain_id!r})")

    # 4. Ground truth per residue, indexed by position in `sequence`. Keying on
    #    the position (not the residue-number string) keeps residues that share a
    #    number across chains, or differ only by insertion code, apart.
    y_true = _residue_interface_labels(res_coords, ply_coords, iface_values, iface_threshold)

    # 5. Model inference. Sequences too long for max_length are truncated by the
    #    tokenizer and then fail the length check below.
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # NOTE(review): sigmoid of the class-1 logit is kept from the original card. If
    # the head was trained with softmax cross-entropy, softmax(logits)[..., 1] is the
    # class-1 probability instead. The two can rank residues differently, so this
    # choice can change the reported ROC-AUC — confirm against the training code.
    all_probs = torch.sigmoid(outputs.logits)[0, :, 1].cpu().numpy()
    # Drop the leading CLS and trailing EOS tokens so scores align with residues.
    y_score = all_probs[1:-1]

    # 6. Metrics
    results = {
        "seq_len": len(sequence),
        "num_interface_res": int(y_true.sum()),
        "roc_auc": None,
    }
    if len(y_true) != len(y_score):
        print(f"❌ Length Mismatch! y_true: {len(y_true)}, y_score: {len(y_score)}")
        return results
    if np.unique(y_true).size < 2:
        # roc_auc_score raises ValueError when only one class is present.
        print(f"❌ ROC-AUC undefined: ground truth contains a single class ({int(y_true[0])}).")
        return results

    roc_auc = float(roc_auc_score(y_true, y_score))
    results["roc_auc"] = roc_auc
    print("\n--- Evaluation Results ---")
    print(f"Sequence Length: {results['seq_len']}")
    print(f"Ground Truth Interface Residues: {results['num_interface_res']}")
    print(f"✅ ROC-AUC Score: {roc_auc:.4f}")
    return results
An example:
# Example: score one structure against its MaSIF surface mesh (the "_A" suffix of
# the .ply file presumably denotes chain A). Adjust both paths to your own files.
res = evaluate_protein_interface(
    pdb_path="/scratch/lb4489/project/ESM/1EWY.pdb",
    ply_path="/scratch/lb4489/project/ESM/1EWY_A.ply",
)
The output will be similar to:
--- Evaluation Results ---
Sequence Length: 704
Ground Truth Interface Residues: 109
✅ ROC-AUC Score: 0.6412
- Downloads last month
- 113
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support
