import os
import tempfile
import traceback
import urllib.request
import importlib.util

import gradio as gr
import torch
import numpy as np
from torch_geometric.data import HeteroData
from torch_geometric.nn import radius_graph, radius

# Import your model class (Make sure model_utils.py is in your Space!)
from model_utils import Struct2SeqGNN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------
# 1. DOWNLOAD & LOAD MODEL WEIGHTS
# ---------------------------------------------------------
raw_github_url = "https://raw.githubusercontent.com/WSobo/Struct2Seq-GNN/main/pretrained_models/v2.0/best_model.pt"
model_path = "best_model.pt"


def _download_file(url, dest_path):
    """Download ``url`` to ``dest_path`` atomically.

    The payload is written to a temporary file in the destination directory
    and only moved onto ``dest_path`` once the download has completed. An
    interrupted download therefore never leaves a truncated ``dest_path``
    behind, which the ``os.path.exists`` cache check would otherwise accept
    on every later start-up.
    """
    dest_dir = os.path.dirname(os.path.abspath(dest_path))
    fd, tmp_path = tempfile.mkstemp(dir=dest_dir, suffix=".part")
    os.close(fd)
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, dest_path)
    except BaseException:
        # Remove the partial download before propagating the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def _strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    DistributedDataParallel saves parameters as ``module.<name>``. Only keys
    that actually start with ``prefix`` are rewritten; other keys are kept
    unchanged, and an empty mapping yields an empty dict.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


if not os.path.exists(model_path):
    print("Downloading model weights from GitHub...")
    _download_file(raw_github_url, model_path)

# Instantiate the model matching your v2.0 training parameters
model = Struct2SeqGNN(
    node_features=6,
    ligand_features=6,
    hidden_dim=256,
    num_classes=21,
    num_layers=6,
    dropout=0.0,
).to(device)

# Load the weights with DDP prefix handling.
# NOTE(security): torch.load unpickles the checkpoint. The file comes from a
# fixed GitHub URL; consider weights_only=True if the Space's torch version
# supports it (torch >= 1.13).
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(_strip_ddp_prefix(state_dict))
model.eval()

# Decoding alphabet: the 20 standard amino acids (one-letter codes, alphabetical)
# plus "X" for unknown. Logit index i of the model's 21 classes decodes to
# AA_ALPHABET[i]. Presumably this matches LigandMPNN's residue ordering used
# for the "S" labels -- TODO confirm.
AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"


# ---------------------------------------------------------
# 2. DATA PROCESSING PIPELINE (PyG HeteroData)
# ---------------------------------------------------------
def _load_ligandmpnn_parsers():
    """Load LigandMPNN parser functions directly from the HF Space root.

    Returns:
        tuple: ``(parse_PDB, featurize)`` taken from ``LigandMPNN_data_utils.py``.

    Raises:
        ImportError: if the file is missing or cannot be loaded as a module.
    """
    parser_file = "LigandMPNN_data_utils.py"
    if not os.path.exists(parser_file):
        raise ImportError(
            f"Could not find {parser_file}. "
            "Please upload LigandMPNN's data_utils.py to the root of your "
            f"Hugging Face Space as {parser_file}."
        )
    spec = importlib.util.spec_from_file_location("ligandmpnn_data_utils", parser_file)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create a module loader for {parser_file}.")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.parse_PDB, module.featurize


parse_PDB, featurize = _load_ligandmpnn_parsers()


def get_ligandmpnn_features(pdb_path, device="cpu"):
    """Parse ``pdb_path`` with LigandMPNN and return its feature dict plus ligand fields.

    If the parsed protein dict has ``chain_letters``, ``chain_mask`` is set to
    all ones (one entry per residue) before featurization. The parser's ``Y``,
    ``Y_t`` and ``Y_m`` entries are copied into the returned dict as
    ``ligand_Y``, ``ligand_Y_t`` and ``ligand_Y_m``, or ``None`` when absent.

    Args:
        pdb_path: path to the input PDB file.
        device: torch device forwarded to the LigandMPNN parser.
    """
    protein_dict, _backbone, _other_atoms, _icodes, _ = parse_PDB(pdb_path, device=device)
    if "chain_letters" in protein_dict:
        protein_dict["chain_mask"] = torch.ones(
            len(protein_dict["chain_letters"]), dtype=torch.int32, device=device
        )
    feature_dict = featurize(protein_dict, cutoff_for_score=8.0)
    feature_dict["ligand_Y"] = protein_dict.get("Y", None)
    feature_dict["ligand_Y_t"] = protein_dict.get("Y_t", None)
    feature_dict["ligand_Y_m"] = protein_dict.get("Y_m", None)
    return feature_dict


def compute_dihedrals(X):
    """Return sin/cos encodings of the backbone phi, psi and omega dihedrals.

    Args:
        X: backbone coordinates indexed as ``X[:, atom, :]`` with atoms
            0/1/2 = N/CA/C (assumed shape ``(L, >=3, 3)`` -- TODO confirm).

    Returns:
        Tensor of shape ``(L, 6)``:
        ``[sin phi, sin psi, sin omega, cos phi, cos psi, cos omega]``.

    At the chain ends the missing neighbour is replaced by the residue's own
    atom. As a result, phi of the first residue and psi/omega of the last
    residue are padding values, not real angles. This is presumably the same
    featurization the model was trained with, so it is intentionally
    unchanged.
    """
    N = X[:, 0, :]
    CA = X[:, 1, :]
    C = X[:, 2, :]
    # Shift by one residue, edge-padding the chain termini.
    C_prev = torch.cat([C[0:1], C[:-1]], dim=0)
    N_next = torch.cat([N[1:], N[-1:]], dim=0)
    CA_next = torch.cat([CA[1:], CA[-1:]], dim=0)

    def dihedral(p0, p1, p2, p3):
        # Signed torsion angle about the p1->p2 bond, in radians.
        b0 = p0 - p1
        b1 = p2 - p1
        b2 = p3 - p2
        b1_norm = b1 / (torch.linalg.norm(b1, dim=-1, keepdim=True) + 1e-7)
        n1 = torch.linalg.cross(b0, b1_norm, dim=-1)
        n2 = torch.linalg.cross(b1_norm, b2, dim=-1)
        m = torch.linalg.cross(n1, b1_norm, dim=-1)
        x = torch.sum(n1 * n2, dim=-1)
        y = torch.sum(m * n2, dim=-1)
        return torch.atan2(y, x)

    phi = dihedral(C_prev, N, CA, C)
    psi = dihedral(N, CA, C, N_next)
    omega = dihedral(CA, C, N_next, CA_next)
    dihedrals = torch.stack([phi, psi, omega], dim=-1)
    return torch.cat([torch.sin(dihedrals), torch.cos(dihedrals)], dim=-1)


# (atomic number, one-hot column) for the explicitly encoded ligand elements:
# C, N, O, S, P. Every other element goes to column 5.
_LIGAND_ELEMENT_COLUMNS = ((6, 0), (7, 1), (8, 2), (16, 3), (15, 4))


def encode_ligand_elements(element_ids):
    """One-hot encode ligand atoms into 6 columns: C, N, O, S, P, other.

    Args:
        element_ids: 1-D tensor of atomic numbers, one per ligand atom.

    Returns:
        float32 tensor of shape ``(M, 6)`` on ``element_ids.device``.
    """
    num_atoms = element_ids.shape[0]
    one_hot = torch.zeros((num_atoms, 6), dtype=torch.float32, device=element_ids.device)
    known = torch.zeros(num_atoms, dtype=torch.bool, device=element_ids.device)
    for atomic_number, column in _LIGAND_ELEMENT_COLUMNS:
        is_element = element_ids == atomic_number
        one_hot[is_element, column] = 1.0
        known |= is_element
    one_hot[~known, 5] = 1.0
    return one_hot


def _set_protein_ligand_edges(data, l_idx, p_idx, dist):
    """Store ligand->protein and protein->ligand 'binds' edges sharing the same distances.

    Args:
        data: HeteroData to mutate.
        l_idx: 1-D long tensor of ligand-atom indices.
        p_idx: 1-D long tensor of protein-residue indices, paired with ``l_idx``.
        dist: ``(E, 1)`` float tensor of pairwise distances.
    """
    data['ligand', 'binds', 'protein'].edge_index = torch.stack([l_idx, p_idx], dim=0)
    data['ligand', 'binds', 'protein'].edge_attr = dist
    data['protein', 'binds', 'ligand'].edge_index = torch.stack([p_idx, l_idx], dim=0)
    data['protein', 'binds', 'ligand'].edge_attr = dist.clone()


def dict_to_pyg_data(feature_dict, radius_cutoff=8.0):
    """Convert a LigandMPNN feature dict into a protein/ligand ``HeteroData`` graph.

    Protein nodes are the residues kept by ``feature_dict["mask"]``. Each has
    6 dihedral features, a CA position and a sequence label. Protein-protein
    edges connect CA atoms within ``radius_cutoff``. Ligand nodes are the
    atoms kept by ``ligand_Y_m``, one-hot encoded by element.
    Protein<->ligand edges connect residues and ligand atoms within
    ``radius_cutoff``. Every edge carries its Euclidean distance as a
    ``(E, 1)`` attribute. When there is no ligand, empty ligand tensors and
    edge sets are still created so the graph schema is always the same.

    Raises:
        ValueError: if ``feature_dict["X"]`` is not per-atom backbone coordinates.
    """
    data = HeteroData()

    # 1. Build Protein Nodes
    X = feature_dict["X"].squeeze(0)
    # compute_dihedrals needs N/CA/C per residue, and pos must be the CA atom (index 1).
    if X.dim() != 3 or X.size(1) < 3:
        raise ValueError(
            f"Expected backbone coordinates of shape (L, >=3, 3), got {tuple(X.shape)}"
        )
    ca_coords = X[:, 1, :]

    sequence_labels = feature_dict["S"].squeeze(0)
    mask = feature_dict["mask"].squeeze(0).bool()
    # Dihedrals use the unmasked chain so neighbouring atoms line up; mask afterwards.
    dihedral_features = compute_dihedrals(X)

    ca_coords = ca_coords[mask]
    sequence_labels = sequence_labels[mask]
    dihedral_features = dihedral_features[mask]

    data['protein'].x = dihedral_features.clone().float()
    data['protein'].pos = ca_coords.clone().float()
    data['protein'].y = sequence_labels.long()
    if "chain_M" in feature_dict:
        data['protein'].chain_M = feature_dict["chain_M"].squeeze(0)[mask]

    p_pos = data['protein'].pos
    pp_edge_index = radius_graph(p_pos, r=radius_cutoff, loop=False)
    p_row, p_col = pp_edge_index
    pp_dist = torch.norm(p_pos[p_row] - p_pos[p_col], dim=1, p=2).unsqueeze(-1)
    data['protein', 'interacts_with', 'protein'].edge_index = pp_edge_index
    data['protein', 'interacts_with', 'protein'].edge_attr = pp_dist

    # 2. Build Ligand Nodes
    Y = feature_dict.get("ligand_Y")
    Y_t = feature_dict.get("ligand_Y_t")
    Y_m = feature_dict.get("ligand_Y_m")
    num_ligand_atoms = 0
    if Y is not None and Y_t is not None and Y_m is not None:
        Y_mask = Y_m.bool()
        if Y_mask.any():
            Y = Y[Y_mask]
            Y_t = Y_t[Y_mask]
            num_ligand_atoms = Y.shape[0]
            data['ligand'].x = encode_ligand_elements(Y_t)
            data['ligand'].pos = Y.float()

    # Default: no protein<->ligand edges (empty tensors keep the schema fixed).
    edge_device = p_pos.device
    l_idx = torch.empty((0,), dtype=torch.long, device=edge_device)
    p_idx = torch.empty((0,), dtype=torch.long, device=edge_device)
    lp_dist = torch.empty((0, 1), dtype=torch.float32, device=edge_device)

    if num_ligand_atoms > 0:
        l_pos = data['ligand'].pos
        # radius(x, y) returns, for each point in y, its neighbours in x:
        # row 0 indexes y (protein residues), row 1 indexes x (ligand atoms).
        pl_edge_index = radius(l_pos, p_pos, r=radius_cutoff)
        if pl_edge_index.size(1) > 0:
            p_idx, l_idx = pl_edge_index[0], pl_edge_index[1]
            lp_dist = torch.norm(l_pos[l_idx] - p_pos[p_idx], dim=1, p=2).unsqueeze(-1)
    else:
        data['ligand'].x = torch.empty((0, 6), dtype=torch.float32, device=edge_device)
        data['ligand'].pos = torch.empty((0, 3), dtype=torch.float32, device=edge_device)

    _set_protein_ligand_edges(data, l_idx, p_idx, lp_dist)
    return data


def pdb_to_pyg_data(pdb_path, radius=8.0, device="cpu"):
    """Parse ``pdb_path`` and build its ``HeteroData`` graph.

    Args:
        pdb_path: path to the input PDB file.
        radius: neighbourhood cutoff passed to ``dict_to_pyg_data``. Note that
            this parameter shadows the imported ``radius`` function, but only
            inside this function.
        device: torch device forwarded to the LigandMPNN parser.
    """
    feature_dict = get_ligandmpnn_features(pdb_path, device=device)
    return dict_to_pyg_data(feature_dict, radius_cutoff=radius)


# ---------------------------------------------------------
# 3. INFERENCE ENDPOINT
# ---------------------------------------------------------
def _uploaded_file_path(pdb_file):
    """Return the filesystem path of a Gradio ``gr.File`` upload.

    Gradio 4+ passes the upload as a path string, while Gradio 3.x passes a
    tempfile-like object exposing ``.name``. Both forms are accepted.
    """
    if isinstance(pdb_file, (str, os.PathLike)):
        return os.fspath(pdb_file)
    return pdb_file.name


def predict_sequence(pdb_file):
    """Run the model on an uploaded PDB and return the designed sequence as display text.

    Any processing error is reported back to the UI as text, and its
    traceback is printed to the Space logs.
    """
    if pdb_file is None:
        return "Please upload a .pdb file."

    try:
        # Build the Heterogeneous Graph
        data = pdb_to_pyg_data(_uploaded_file_path(pdb_file), device=device)
        data = data.to(device)
        num_residues = data['protein'].x.shape[0]

        # Run the forward pass
        with torch.no_grad():
            logits = model(data)

        # Decode logits to an amino acid string
        predicted_indices = torch.argmax(logits, dim=-1).cpu().tolist()
        predicted_seq = "".join(AA_ALPHABET[idx] for idx in predicted_indices)

        return f"Predicted Sequence ({num_residues} residues):\n\n{predicted_seq}"

    except Exception as e:
        # UI boundary: surface the message to the user, keep the traceback in the logs.
        traceback.print_exc()
        return f"Error processing PDB: {str(e)}"


# ---------------------------------------------------------
# 4. GRADIO UI
# ---------------------------------------------------------
demo = gr.Interface(
    fn=predict_sequence,
    inputs=gr.File(label="Upload Target Protein Backbone (.pdb)", file_types=[".pdb"]),
    outputs=gr.Textbox(label="Designed Amino Acid Sequence", lines=5),
    title="Struct2Seq-GNN: Inverse Protein Folding",
    description=(
        "Upload a 3D target backbone to generate a sequence optimized by a custom Heterogeneous Graph Neural Network.\n\n"
        "**Model Performance:** Achieves ~30.3% global sequence recovery and **35.1% binding-pocket recovery** "
        "on noisy coordinates, confirming strong generalization to underlying biophysical folding constraints."
    ),
)

if __name__ == "__main__":
    demo.launch()