A protein interaction-interface prediction model fine-tuned from the [ESM2-8M model](https://www.science.org/doi/full/10.1126/science.ade2574). Model deployment follows Synthyra's [FastESM](https://huggingface.co/Synthyra) series.

The dataset comes from the [MaSIF project](https://github.com/LPDI-EPFL/masif).

This model was trained only for auxiliary purposes and is not recommended for direct use. For predicting interaction interfaces, the 150M and 650M ESM2 models perform significantly better than the 8M model.

evaluation_metrics

import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Hugging Face Hub identifier of the fine-tuned ESM2-8M interface model.
model_id = "leexiaohua/iface_small"

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_id)

# trust_remote_code=True is required because the checkpoint is loaded through
# custom modelling code (the card says deployment follows Synthyra's FastESM).
model = AutoModelForTokenClassification.from_pretrained(
    model_id,
    trust_remote_code=True,
)
model = model.to(device)  # nn.Module.to returns the same module
model.eval()  # inference mode: disables dropout
import torch
import numpy as np
import trimesh
import pandas as pd
from transformers import AutoTokenizer, AutoModelForTokenClassification
from scipy.spatial import KDTree
from sklearn.metrics import roc_auc_score

# Three-letter to one-letter amino-acid codes; anything else maps to 'X'.
_THREE_TO_ONE = {
    "CYS": "C", "ASP": "D", "SER": "S", "GLN": "Q", "LYS": "K",
    "ILE": "I", "PRO": "P", "THR": "T", "PHE": "F", "ASN": "N",
    "GLY": "G", "HIS": "H", "LEU": "L", "ARG": "R", "TRP": "W",
    "ALA": "A", "VAL": "V", "GLU": "E", "TYR": "Y", "MET": "M",
}


def _parse_pdb_ca(pdb_path, chain=None):
    """Read CA coordinates, residue keys and the one-letter sequence from a PDB file.

    Only ATOM records whose atom-name field (columns 13-16) is exactly "CA"
    are used. Reading stops at the first ENDMDL, so only the first model of a
    multi-model file is used. Residues are keyed by (chain ID, residue
    number, insertion code), and only the first CA seen for a key is kept.
    This drops alternate locations, and identically numbered residues in
    different chains are never merged.

    Args:
        pdb_path (str): Path to the .pdb file.
        chain (str | None): Restrict to this chain ID; None keeps all chains.

    Returns:
        tuple: (coords as a float array of shape (n, 3), list of n residue
        keys, sequence string of length n).
    """
    coords, keys, letters = [], [], []
    seen = set()
    with open(pdb_path, "r") as handle:
        for line in handle:
            if line.startswith("ENDMDL"):
                break  # only the first model of multi-model files
            if not line.startswith("ATOM") or line[12:16].strip() != "CA":
                continue
            chain_id = line[21]
            if chain is not None and chain_id != chain:
                continue
            key = (chain_id, line[22:26].strip(), line[26].strip())
            if key in seen:
                continue  # alternate location of a CA that is already recorded
            seen.add(key)
            coords.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
            keys.append(key)
            letters.append(_THREE_TO_ONE.get(line[17:20].strip(), "X"))
    return np.array(coords, dtype=float).reshape(-1, 3), keys, "".join(letters)


def _load_ply_iface(ply_path):
    """Load mesh vertex coordinates and the per-vertex 'iface' values from a PLY file.

    Raises:
        ValueError: If the vertex count differs from the length of 'iface'.
    """
    # process=False stops trimesh from merging or removing vertices, so vertex i
    # stays aligned with entry i of the raw 'iface' vertex property.
    mesh = trimesh.load(ply_path, process=False)
    coords = np.asarray(mesh.vertices, dtype=float)
    iface = np.asarray(
        mesh.metadata["_ply_raw"]["vertex"]["data"]["iface"], dtype=float
    ).ravel()
    if len(coords) != len(iface):
        raise ValueError(
            f"PLY vertex count ({len(coords)}) does not match 'iface' length "
            f"({len(iface)}) in {ply_path!r}"
        )
    return coords, iface


def evaluate_protein_interface(
    pdb_path,
    ply_path,
    model_id="leexiaohua/iface_small",
    *,
    chain=None,
    threshold=0.5,
    model=None,
    tokenizer=None,
    device=None,
):
    """
    Evaluate the model's per-residue interface predictions against MaSIF ground truth.

    Each PLY vertex is assigned to the residue with the nearest CA atom. A
    residue is labelled as interface (1) when the maximum 'iface' value over
    its vertices exceeds ``threshold``. Residues with no assigned vertex are
    labelled 0. The labels are compared with the model scores using ROC-AUC.

    Args:
        pdb_path (str): Path to the .pdb file.
        ply_path (str): Path to the .ply file containing 'iface' vertex data.
        model_id (str): Hugging Face model ID. Used only for whichever of
            ``model`` / ``tokenizer`` is not supplied.
        chain (str | None): If given, only CA atoms of this chain are used.
            With None, all chains are concatenated into one sequence. Chain
            breaks are then invisible to the model, and vertices may be
            assigned to partner-chain residues.
        threshold (float): Cut-off on the per-residue max 'iface' value.
        model: Optional preloaded token-classification model. It is moved to
            ``device`` and switched to eval mode in place.
        tokenizer: Optional preloaded tokenizer matching ``model``.
        device (torch.device | str | None): Inference device. Defaults to
            CUDA if available, else CPU.

    Returns:
        dict: A dictionary with these keys:
            - ``seq_len``: length of the evaluated sequence.
            - ``num_interface_res``: number of ground-truth interface residues.
            - ``roc_auc``: the ROC-AUC, or None when the label and score
              lengths differ or the labels contain a single class.

    Raises:
        ValueError: If no usable CA atoms are found, or the PLY vertex count
            does not match its 'iface' array.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load model and tokenizer, reusing caller-supplied ones when given.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    if model is None:
        model = AutoModelForTokenClassification.from_pretrained(
            model_id, trust_remote_code=True
        )
    model.to(device)
    model.eval()

    # 2. Surface vertices and their ground-truth 'iface' values.
    ply_coords, iface_values = _load_ply_iface(ply_path)

    # 3. CA coordinates, residue keys and sequence from the PDB.
    res_coords, res_keys, sequence = _parse_pdb_ca(pdb_path, chain=chain)
    if not res_keys:
        where = f" for chain {chain!r}" if chain is not None else ""
        raise ValueError(f"No CA atoms found in {pdb_path!r}{where}")

    # 4. Map each vertex to its nearest CA and take the per-residue maximum.
    _, nearest = KDTree(res_coords).query(ply_coords)
    per_res_max = np.full(len(res_keys), -np.inf)
    np.maximum.at(per_res_max, np.asarray(nearest, dtype=np.intp), iface_values)
    # Residues that received no vertex keep -inf, so they are labelled 0.
    y_true = (per_res_max > threshold).astype(int)

    # 5. Model inference.
    inputs = tokenizer(
        sequence, return_tensors="pt", truncation=True, max_length=1024
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Score = sigmoid of the class-1 (interface) logit, as in the original card.
    # NOTE(review): if the head was trained with softmax cross-entropy,
    # torch.softmax(logits, -1)[..., 1] may be the intended probability; confirm.
    all_probs = torch.sigmoid(logits)[0, :, 1].float().cpu().numpy()
    # Drop the first and last tokens (assumes one CLS/BOS and one EOS token,
    # as ESM tokenizers add). Truncated inputs then fail the length check below.
    y_score = all_probs[1:-1]

    # 6. Metrics.
    results = {
        "seq_len": len(sequence),
        "num_interface_res": int(y_true.sum()),
        "roc_auc": None,
    }

    if len(y_true) != len(y_score):
        print(f"❌ Length Mismatch! y_true: {len(y_true)}, y_score: {len(y_score)}")
        return results
    if np.unique(y_true).size < 2:
        # roc_auc_score is undefined (and raises) when only one class is present.
        print("❌ ROC-AUC undefined: ground truth contains a single class.")
        return results

    roc_auc = roc_auc_score(y_true, y_score)
    results["roc_auc"] = roc_auc

    print("\n--- Evaluation Results ---")
    print(f"Sequence Length: {results['seq_len']}")
    print(f"Ground Truth Interface Residues: {results['num_interface_res']}")
    print(f"✅ ROC-AUC Score: {roc_auc:.4f}")
    return results

Example usage:

# The paths below come from the author's environment; replace them with your
# own PDB file and the matching MaSIF PLY surface (the "_A" suffix presumably
# denotes chain A).
res = evaluate_protein_interface(
    pdb_path="/scratch/lb4489/project/ESM/1EWY.pdb",
    ply_path="/scratch/lb4489/project/ESM/1EWY_A.ply",
)

The output will be similar to:

--- Evaluation Results ---
Sequence Length: 704
Ground Truth Interface Residues: 109
✅ ROC-AUC Score: 0.6412
Downloads last month
113
Safetensors
Model size
8.01M params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support