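# Hugging Face Space page header (kept as a comment so it is not parsed as code):
# author avatar: WSobo's picture | commit message: "Update app.py" | commit 060ccef (verified)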
import os
import urllib.request
import importlib.util
import gradio as gr
import torch
import numpy as np
from torch_geometric.data import HeteroData
from torch_geometric.nn import radius_graph, radius
# Import your model class (Make sure model_utils.py is in your Space!)
from model_utils import Struct2SeqGNN
# Use the GPU when one is available; the model and the checkpoint tensors are
# placed on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------
# 1. DOWNLOAD & LOAD MODEL WEIGHTS
# ---------------------------------------------------------
raw_github_url = "https://raw.githubusercontent.com/WSobo/Struct2Seq-GNN/main/pretrained_models/v2.0/best_model.pt"
model_path = "best_model.pt"

if not os.path.exists(model_path):
    print("Downloading model weights from GitHub...")
    # Download under a temporary name and rename only after the transfer
    # finished. Otherwise an interrupted download would leave a truncated
    # file at model_path, and every later start would skip the download
    # (the file "exists") and fail while loading it.
    _partial_path = model_path + ".part"
    try:
        urllib.request.urlretrieve(raw_github_url, _partial_path)
        os.replace(_partial_path, model_path)
    finally:
        # Only still present when the download or the rename failed.
        if os.path.exists(_partial_path):
            os.remove(_partial_path)

# Instantiate the model matching your v2.0 training parameters
model = Struct2SeqGNN(
    node_features=6,
    ligand_features=6,
    hidden_dim=256,
    num_classes=21,
    num_layers=6,
    dropout=0.0
).to(device)

# Load the weights with DDP prefix handling.
# NOTE(security): torch.load unpickles the file, and the file comes from the
# network. Consider passing weights_only=True once the checkpoint is confirmed
# to be a plain state dict and the installed torch version supports it.
state_dict = torch.load(model_path, map_location=device)

# Checkpoints saved from a DistributedDataParallel wrapper prefix every
# parameter name with 'module.'. Strip the prefix only from keys that actually
# carry it, so mixed or empty state dicts are handled without mangling names.
_DDP_PREFIX = 'module.'
if any(key.startswith(_DDP_PREFIX) for key in state_dict):
    state_dict = {
        (key[len(_DDP_PREFIX):] if key.startswith(_DDP_PREFIX) else key): value
        for key, value in state_dict.items()
    }
model.load_state_dict(state_dict)
model.eval()

# Standard Amino Acid alphabet: predicted class index i is decoded as
# AA_ALPHABET[i] (21 symbols, matching num_classes=21 above).
AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"
# ---------------------------------------------------------
# 2. DATA PROCESSING PIPELINE (PyG HeteroData)
# ---------------------------------------------------------
def _load_ligandmpnn_parsers():
    """Load LigandMPNN's ``parse_PDB`` and ``featurize`` from a local file.

    The functions are imported from ``LigandMPNN_data_utils.py``, looked up
    relative to the current working directory (the HF Space root when the app
    is started there), and executed as the module ``ligandmpnn_data_utils``.

    Returns:
        A ``(parse_PDB, featurize)`` tuple of the two callables defined in
        that file.

    Raises:
        ImportError: If the file is missing or no import spec/loader can be
            created for it.
        AttributeError: If the file does not define ``parse_PDB`` or
            ``featurize``.
    """
    parser_file = "LigandMPNN_data_utils.py"
    if not os.path.exists(parser_file):
        # Name the file that is actually looked up, so the message tells the
        # user exactly what to upload and under which name.
        raise ImportError(
            f"Could not find {parser_file}. "
            f"Please upload the LigandMPNN data_utils.py file as {parser_file} "
            "to the root of your Hugging Face Space."
        )

    spec = importlib.util.spec_from_file_location("ligandmpnn_data_utils", parser_file)
    # Explicit check instead of ``assert``: asserts are stripped under -O and
    # ``spec`` itself may be None when no loader matches the file.
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create an import spec for {parser_file}.")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.parse_PDB, module.featurize
# Resolve the LigandMPNN parser functions once at import time; a missing
# LigandMPNN_data_utils.py therefore raises ImportError at startup rather
# than on the first upload.
parse_PDB, featurize = _load_ligandmpnn_parsers()
def get_ligandmpnn_features(pdb_path, device="cpu"):
    """Parse a PDB file and return LigandMPNN features plus raw ligand fields.

    Args:
        pdb_path: Path of the PDB file handed to ``parse_PDB``.
        device: Device passed to ``parse_PDB`` and used for the chain mask.

    Returns:
        The dict produced by ``featurize`` (called with
        ``cutoff_for_score=8.0``), extended with ``ligand_Y``,
        ``ligand_Y_t`` and ``ligand_Y_m`` copied from the parsed protein
        dict's ``Y``, ``Y_t`` and ``Y_m`` entries (``None`` when absent).
    """
    # parse_PDB returns five values; only the first (the protein dict) is used.
    protein_dict, _backbone, _other_atoms, _icodes, _ = parse_PDB(pdb_path, device=device)

    if "chain_letters" in protein_dict:
        # All-ones mask with one entry per chain letter.
        # NOTE(review): presumably this marks every position as designable
        # for featurize; confirm against LigandMPNN.
        entry_count = len(protein_dict["chain_letters"])
        protein_dict["chain_mask"] = torch.ones(
            entry_count, dtype=torch.int32, device=device
        )

    features = featurize(protein_dict, cutoff_for_score=8.0)

    # Carry the ligand fields of the parsed dict over under "ligand_*" keys.
    ligand_fields = (
        ("ligand_Y", "Y"),
        ("ligand_Y_t", "Y_t"),
        ("ligand_Y_m", "Y_m"),
    )
    for target_key, source_key in ligand_fields:
        features[target_key] = protein_dict.get(source_key)
    return features
def compute_dihedrals(X):
    """Compute sin/cos encoded backbone dihedral features per residue.

    Args:
        X: Coordinate tensor indexed as ``X[residue, atom, xyz]`` where atom
            slots 0, 1 and 2 are read as N, CA and C.

    Returns:
        Tensor of shape ``(num_residues, 6)`` laid out as
        ``[sin(phi), sin(psi), sin(omega), cos(phi), cos(psi), cos(omega)]``.

    Chain ends have no real neighbour, so they are padded with themselves:
    the first residue reuses its own C as the "previous" C, and the last
    residue reuses its own N and CA as the "next" ones.
    """

    def _torsion(a, b, c, d):
        # Signed angle around the b->c axis, computed with atan2 of the
        # components of the two plane normals.
        arm_in = a - b
        axis = c - b
        arm_out = d - c
        # Small epsilon keeps the normalisation finite for coincident atoms.
        unit_axis = axis / (torch.linalg.norm(axis, dim=-1, keepdim=True) + 1e-7)
        normal_a = torch.linalg.cross(arm_in, unit_axis, dim=-1)
        normal_b = torch.linalg.cross(unit_axis, arm_out, dim=-1)
        in_plane = torch.linalg.cross(normal_a, unit_axis, dim=-1)
        cos_term = (normal_a * normal_b).sum(dim=-1)
        sin_term = (in_plane * normal_b).sum(dim=-1)
        return torch.atan2(sin_term, cos_term)

    n_atoms, ca_atoms, c_atoms = X[:, 0, :], X[:, 1, :], X[:, 2, :]

    # Shift by one residue, duplicating the boundary residue as padding.
    prev_c = torch.cat([c_atoms[:1], c_atoms[:-1]], dim=0)
    next_n = torch.cat([n_atoms[1:], n_atoms[-1:]], dim=0)
    next_ca = torch.cat([ca_atoms[1:], ca_atoms[-1:]], dim=0)

    angles = torch.stack(
        [
            _torsion(prev_c, n_atoms, ca_atoms, c_atoms),    # phi
            _torsion(n_atoms, ca_atoms, c_atoms, next_n),    # psi
            _torsion(ca_atoms, c_atoms, next_n, next_ca),    # omega
        ],
        dim=-1,
    )
    return torch.cat([angles.sin(), angles.cos()], dim=-1)
def encode_ligand_elements(element_ids):
    """One-hot encode ligand atoms into six element classes.

    Args:
        element_ids: 1-D tensor of element ids, one per ligand atom.

    Returns:
        Float32 tensor of shape ``(num_atoms, 6)`` on the input's device.
        Columns 0-4 flag ids 6, 7, 8, 16 and 15 (C, N, O, S, P, in that
        column order); column 5 flags every other id.
    """
    # Column order is C, N, O, S, P; note S (16) comes before P (15).
    column_ids = (6, 7, 8, 16, 15)

    atom_count = element_ids.shape[0]
    encoding = torch.zeros(
        (atom_count, 6), dtype=torch.float32, device=element_ids.device
    )

    # Tracks atoms that matched none of the listed elements.
    unmatched = torch.ones_like(element_ids, dtype=torch.bool)
    for column, atomic_id in enumerate(column_ids):
        hits = element_ids == atomic_id
        encoding[hits, column] = 1.0
        unmatched = unmatched & ~hits

    encoding[unmatched, 5] = 1.0
    return encoding
def _attach_empty_ligand_edges(data, device):
    """Register both ligand/protein relations on ``data`` with zero edges."""
    for relation in (('ligand', 'binds', 'protein'), ('protein', 'binds', 'ligand')):
        data[relation].edge_index = torch.empty((2, 0), dtype=torch.long, device=device)
        data[relation].edge_attr = torch.empty((0, 1), dtype=torch.float32, device=device)


def dict_to_pyg_data(feature_dict, radius_cutoff=8.0):
    """Convert a feature dict into a PyG ``HeteroData`` protein/ligand graph.

    Args:
        feature_dict: Dict with ``"X"``, ``"S"`` and ``"mask"`` tensors (each
            has a leading dimension removed with ``squeeze(0)``), optionally
            ``"chain_M"``, and optionally ``"ligand_Y"`` (coordinates),
            ``"ligand_Y_t"`` (element ids) and ``"ligand_Y_m"`` (atom mask).
        radius_cutoff: Distance threshold used for both the protein-protein
            and the protein-ligand neighbour search.

    Returns:
        ``HeteroData`` with:
          * ``protein`` nodes: ``x`` (dihedral features), ``pos`` (CA
            coordinates), ``y`` (sequence labels), optional ``chain_M``;
            only residues where ``mask`` is true are kept.
          * ``protein -interacts_with-> protein`` edges with the Euclidean
            distance as ``edge_attr``.
          * ``ligand`` nodes: ``x`` (6-way element one-hot) and ``pos``.
          * ``ligand -binds-> protein`` and ``protein -binds-> ligand`` edges
            with distances as ``edge_attr``.
        Ligand nodes and both ``binds`` relations are always present; they
        are empty tensors when there is no ligand or no atom within range.

    Raises:
        ValueError: If masked ligand coordinates are present but
            ``"ligand_Y_t"`` is missing.
    """
    data = HeteroData()

    # 1. Build Protein Nodes
    X = feature_dict["X"].squeeze(0)
    if X.dim() == 3 and X.size(1) >= 4:
        ca_coords = X[:, 1, :]
    else:
        # NOTE(review): compute_dihedrals below indexes X[:, 0..2, :], so this
        # fallback only succeeds for inputs that call can also handle.
        ca_coords = X
    sequence_labels = feature_dict["S"].squeeze(0)
    mask = feature_dict["mask"].squeeze(0).bool()

    # Dihedrals are computed on the full chain first, then masked, so that
    # neighbouring residues are still available for the angle computation.
    dihedral_features = compute_dihedrals(X)

    ca_coords = ca_coords[mask]
    sequence_labels = sequence_labels[mask]
    dihedral_features = dihedral_features[mask]

    data['protein'].x = dihedral_features.clone().float()
    data['protein'].pos = ca_coords.clone().float()
    data['protein'].y = sequence_labels.long()
    if "chain_M" in feature_dict:
        data['protein'].chain_M = feature_dict["chain_M"].squeeze(0)[mask]

    p_pos = data['protein'].pos
    # Placeholders are created on the same device as the real tensors so the
    # returned graph never mixes devices.
    graph_device = p_pos.device

    pp_edge_index = radius_graph(p_pos, r=radius_cutoff, loop=False)
    p_row, p_col = pp_edge_index
    pp_dist = torch.norm(p_pos[p_row] - p_pos[p_col], dim=1, p=2).unsqueeze(-1)
    data['protein', 'interacts_with', 'protein'].edge_index = pp_edge_index
    data['protein', 'interacts_with', 'protein'].edge_attr = pp_dist

    # 2. Build Ligand Nodes
    Y = feature_dict.get("ligand_Y")
    Y_t = feature_dict.get("ligand_Y_t")
    Y_m = feature_dict.get("ligand_Y_m")

    has_ligand = False
    if Y is not None and Y_m is not None:
        Y_mask = Y_m.bool()
        if Y_mask.sum() > 0:
            if Y_t is None:
                raise ValueError(
                    "feature_dict['ligand_Y_t'] (ligand element ids) is required "
                    "when ligand coordinates are provided."
                )
            Y = Y[Y_mask]
            Y_t = Y_t[Y_mask]
            data['ligand'].x = encode_ligand_elements(Y_t)
            data['ligand'].pos = Y.float()
            has_ligand = True

    if not has_ligand:
        # No usable ligand atoms: keep the schema stable with empty tensors.
        data['ligand'].x = torch.empty((0, 6), dtype=torch.float32, device=graph_device)
        data['ligand'].pos = torch.empty((0, 3), dtype=torch.float32, device=graph_device)
        _attach_empty_ligand_edges(data, graph_device)
        return data

    l_pos = data['ligand'].pos
    # radius(x, y, r) pairs every point of y (proteins) with the points of x
    # (ligand atoms) within r: row 0 indexes proteins, row 1 indexes ligands.
    pl_edge_index = radius(l_pos, p_pos, r=radius_cutoff)
    if pl_edge_index.size(1) == 0:
        _attach_empty_ligand_edges(data, graph_device)
        return data

    p_idx, l_idx = pl_edge_index[0], pl_edge_index[1]
    lp_dist = torch.norm(l_pos[l_idx] - p_pos[p_idx], dim=1, p=2).unsqueeze(-1)

    data['ligand', 'binds', 'protein'].edge_index = torch.stack([l_idx, p_idx], dim=0)
    data['ligand', 'binds', 'protein'].edge_attr = lp_dist

    # Reverse relation shares the same pairs; distances are cloned so the two
    # edge stores do not alias one tensor.
    data['protein', 'binds', 'ligand'].edge_index = torch.stack([p_idx, l_idx], dim=0)
    data['protein', 'binds', 'ligand'].edge_attr = lp_dist.clone()
    return data
def pdb_to_pyg_data(pdb_path, radius=8.0, device="cpu"):
    """Build the heterogeneous graph for a PDB file in one call.

    Args:
        pdb_path: Path of the PDB file to parse.
        radius: Neighbour cutoff forwarded as ``radius_cutoff`` to
            ``dict_to_pyg_data``.
        device: Device forwarded to ``get_ligandmpnn_features``.

    Returns:
        The graph object returned by ``dict_to_pyg_data``.
    """
    # Two-stage pipeline: PDB file -> feature dict -> graph.
    return dict_to_pyg_data(
        get_ligandmpnn_features(pdb_path, device=device),
        radius_cutoff=radius,
    )
# ---------------------------------------------------------
# 3. INFERENCE ENDPOINT
# ---------------------------------------------------------
def predict_sequence(pdb_file):
    """Gradio callback: predict an amino-acid sequence for an uploaded PDB.

    Args:
        pdb_file: The upload handed over by ``gr.File``: either an object
            exposing the file path as ``.name`` or a plain path string.
            ``None`` when nothing was uploaded.

    Returns:
        A text message: the predicted sequence with its residue count, a
        prompt to upload a file, or an ``"Error processing PDB: ..."`` line.
        Errors are reported as text instead of being raised, so the UI
        always receives a string.
    """
    if pdb_file is None:
        return "Please upload a .pdb file."

    # Depending on the Gradio version/configuration the File component
    # passes a tempfile-like object with .name or the path itself.
    pdb_path = getattr(pdb_file, "name", pdb_file)

    try:
        # Build the Heterogeneous Graph
        data = pdb_to_pyg_data(pdb_path, device=device)
        data = data.to(device)
        num_residues = data['protein'].x.shape[0]

        # Run the forward pass
        with torch.no_grad():
            logits = model(data)

        # Decode logits to an amino acid string
        predicted_indices = torch.argmax(logits, dim=-1).cpu().tolist()
        predicted_seq = "".join(AA_ALPHABET[idx] for idx in predicted_indices)
        return f"Predicted Sequence ({num_residues} residues):\n\n{predicted_seq}"
    except Exception as e:
        # UI boundary: keep returning a readable message, but also emit the
        # full traceback to the server log so failures can be diagnosed.
        import traceback
        traceback.print_exc()
        return f"Error processing PDB: {str(e)}"
# ---------------------------------------------------------
# 4. GRADIO UI
# ---------------------------------------------------------
# Upload widget restricted to .pdb files; its value is passed to predict_sequence.
_pdb_upload = gr.File(label="Upload Target Protein Backbone (.pdb)", file_types=[".pdb"])

# Output box that displays the string returned by predict_sequence.
_sequence_box = gr.Textbox(label="Designed Amino Acid Sequence", lines=5)

# Markdown blurb shown under the title.
_demo_description = (
    "Upload a 3D target backbone to generate a sequence optimized by a custom Heterogeneous Graph Neural Network.\n\n"
    "**Model Performance:** Achieves ~30.3% global sequence recovery and **35.1% binding-pocket recovery** "
    "on noisy coordinates, confirming strong generalization to underlying biophysical folding constraints."
)

demo = gr.Interface(
    fn=predict_sequence,
    inputs=_pdb_upload,
    outputs=_sequence_box,
    title="Struct2Seq-GNN: Inverse Protein Folding",
    description=_demo_description,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()