Spaces:
Runtime error
Runtime error
| import os | |
| import urllib.request | |
| import importlib.util | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from torch_geometric.data import HeteroData | |
| from torch_geometric.nn import radius_graph, radius | |
| # Import your model class (Make sure model_utils.py is in your Space!) | |
| from model_utils import Struct2SeqGNN | |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------
# 1. DOWNLOAD & LOAD MODEL WEIGHTS
# ---------------------------------------------------------
raw_github_url = "https://raw.githubusercontent.com/WSobo/Struct2Seq-GNN/main/pretrained_models/v2.0/best_model.pt"
model_path = "best_model.pt"

if not os.path.exists(model_path):
    print("Downloading model weights from GitHub...")
    # Download under a temporary name and rename only after the transfer
    # succeeded. Writing straight to model_path would leave a truncated file
    # behind on an interrupted download, and the existence check above would
    # then accept that broken file on every later start.
    _partial_model_path = model_path + ".part"
    try:
        urllib.request.urlretrieve(raw_github_url, _partial_model_path)
        os.replace(_partial_model_path, model_path)
    finally:
        # Only present here if the download or the rename failed.
        if os.path.exists(_partial_model_path):
            os.remove(_partial_model_path)
# Instantiate the model matching your v2.0 training parameters
model = Struct2SeqGNN(
    node_features=6,
    ligand_features=6,
    hidden_dim=256,
    num_classes=21,
    num_layers=6,
    dropout=0.0,
).to(device)

# Load the weights with DDP prefix handling.
# NOTE(review): torch.load unpickles the downloaded file, and unpickling can
# execute arbitrary code. If the installed torch version supports it and the
# checkpoint is a plain state dict, consider passing weights_only=True.
state_dict = torch.load(model_path, map_location=device)

# Checkpoints saved from a DistributedDataParallel wrapper prefix every
# parameter name with "module.". Strip the prefix per key instead of deciding
# from the first key alone, so unprefixed keys are never truncated and an
# empty checkpoint reaches load_state_dict's own error instead of IndexError.
_ddp_prefix = 'module.'
state_dict = {
    (key[len(_ddp_prefix):] if key.startswith(_ddp_prefix) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()

# Standard Amino Acid alphabet: position i is the letter emitted for
# predicted class index i (21 symbols, matching num_classes above).
AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"
| # --------------------------------------------------------- | |
| # 2. DATA PROCESSING PIPELINE (PyG HeteroData) | |
| # --------------------------------------------------------- | |
def _load_ligandmpnn_parsers():
    """Load the LigandMPNN parser functions from ``LigandMPNN_data_utils.py``.

    The file is looked up by relative path (i.e. in the current working
    directory) and executed as a module named ``ligandmpnn_data_utils``.

    Returns:
        tuple: ``(parse_PDB, featurize)`` as defined in that file.

    Raises:
        ImportError: If the file does not exist or no import spec/loader
            can be created for it.
        AttributeError: If the file does not define ``parse_PDB`` or
            ``featurize``.
    """
    parser_file = "LigandMPNN_data_utils.py"
    if not os.path.exists(parser_file):
        # Name the exact file that is looked up, so the message tells the
        # user what to call the upload.
        raise ImportError(
            f"Could not find {parser_file}. "
            f"Please upload the LigandMPNN data_utils.py file as {parser_file} "
            "to the root of your Hugging Face Space."
        )

    spec = importlib.util.spec_from_file_location("ligandmpnn_data_utils", parser_file)
    # Validate explicitly: an assert would be stripped under ``python -O``.
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create an import spec for {parser_file}.")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.parse_PDB, module.featurize
# Resolve the LigandMPNN parser entry points once, at import time; a missing
# parser file therefore aborts start-up with the loader's ImportError.
parse_PDB, featurize = _load_ligandmpnn_parsers()
def get_ligandmpnn_features(pdb_path, device="cpu"):
    """Parse a PDB file and return the LigandMPNN feature dictionary.

    Args:
        pdb_path: Path handed to ``parse_PDB``.
        device: Device forwarded to ``parse_PDB`` and used for the chain mask.

    Returns:
        The dictionary produced by ``featurize``, extended with the keys
        ``ligand_Y``, ``ligand_Y_t`` and ``ligand_Y_m`` copied from the parsed
        protein dictionary (``None`` when the parser did not provide them).
    """
    # parse_PDB yields five values; only the first one is used here.
    protein_dict, _backbone, _other_atoms, _icodes, _ = parse_PDB(pdb_path, device=device)

    if "chain_letters" in protein_dict:
        # Mark every entry of chain_letters with 1.
        num_entries = len(protein_dict["chain_letters"])
        protein_dict["chain_mask"] = torch.ones(
            num_entries, dtype=torch.int32, device=device
        )

    feature_dict = featurize(protein_dict, cutoff_for_score=8.0)

    # Carry the raw ligand entries along under a "ligand_" prefix.
    for source_key in ("Y", "Y_t", "Y_m"):
        feature_dict["ligand_" + source_key] = protein_dict.get(source_key, None)
    return feature_dict
def compute_dihedrals(X):
    """Return sin/cos features of the phi, psi and omega backbone angles.

    Args:
        X: Coordinate tensor indexed as ``X[:, atom, :]`` where atoms 0, 1
            and 2 are read as N, CA and C.
            (assumes shape (L, >=3, 3) — TODO confirm against caller)

    Returns:
        Tensor whose last dimension has 6 entries per residue:
        ``sin(phi), sin(psi), sin(omega), cos(phi), cos(psi), cos(omega)``.
    """

    def _torsion(a, b, c, d):
        # Signed angle around the b->c axis, evaluated with atan2.
        to_a = a - b
        axis = c - b
        to_d = d - c
        unit_axis = axis / (torch.linalg.norm(axis, dim=-1, keepdim=True) + 1e-7)
        normal_left = torch.linalg.cross(to_a, unit_axis, dim=-1)
        normal_right = torch.linalg.cross(unit_axis, to_d, dim=-1)
        in_plane = torch.linalg.cross(normal_left, unit_axis, dim=-1)
        cos_term = (normal_left * normal_right).sum(dim=-1)
        sin_term = (in_plane * normal_right).sum(dim=-1)
        return torch.atan2(sin_term, cos_term)

    n_atoms, ca_atoms, c_atoms = X[:, 0, :], X[:, 1, :], X[:, 2, :]

    # Neighbouring-residue atoms; the chain ends reuse their own atom because
    # there is no previous/next residue to take it from.
    prev_c = torch.cat([c_atoms[0:1], c_atoms[:-1]], dim=0)
    next_n = torch.cat([n_atoms[1:], n_atoms[-1:]], dim=0)
    next_ca = torch.cat([ca_atoms[1:], ca_atoms[-1:]], dim=0)

    angles = torch.stack(
        [
            _torsion(prev_c, n_atoms, ca_atoms, c_atoms),    # phi
            _torsion(n_atoms, ca_atoms, c_atoms, next_n),    # psi
            _torsion(ca_atoms, c_atoms, next_n, next_ca),    # omega
        ],
        dim=-1,
    )
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
def encode_ligand_elements(element_ids):
    """One-hot encode ligand element ids into six classes.

    Columns 0-4 are set for ids 6, 7, 8, 16 and 15 (C, N, O, S, P) and
    column 5 is set for every other id.

    Args:
        element_ids: Tensor of element ids, one per ligand atom.

    Returns:
        A float32 tensor of shape ``(len(element_ids), 6)`` on the same
        device as ``element_ids``.
    """
    # Column order of the encoding: C, N, O, S, P.
    known_elements = (6, 7, 8, 16, 15)

    num_atoms = element_ids.shape[0]
    encoding = torch.zeros(
        (num_atoms, 6), dtype=torch.float32, device=element_ids.device
    )

    matched_any = None
    for column, element_id in enumerate(known_elements):
        is_element = element_ids == element_id
        encoding[is_element, column] = 1.0
        matched_any = is_element if matched_any is None else (matched_any | is_element)

    # Everything that matched none of the known elements goes to the last column.
    encoding[~matched_any, 5] = 1.0
    return encoding
def _store_binding_edges(data, l_idx, p_idx, dist):
    """Write ligand->protein edges and the mirrored protein->ligand edges."""
    data['ligand', 'binds', 'protein'].edge_index = torch.stack([l_idx, p_idx], dim=0)
    data['ligand', 'binds', 'protein'].edge_attr = dist
    data['protein', 'binds', 'ligand'].edge_index = torch.stack([p_idx, l_idx], dim=0)
    data['protein', 'binds', 'ligand'].edge_attr = dist.clone()


def dict_to_pyg_data(feature_dict, radius_cutoff=8.0):
    """Convert a feature dictionary into a protein/ligand ``HeteroData`` graph.

    Reads ``X``, ``S`` and ``mask`` (optionally ``chain_M``) for the protein
    and ``ligand_Y``, ``ligand_Y_t`` and ``ligand_Y_m`` for the ligand.

    Args:
        feature_dict: Mapping with the keys listed above. The protein entries
            are ``squeeze(0)``-ed before use.
        radius_cutoff: Distance threshold used for protein-protein edges and
            for protein-ligand edges.

    Returns:
        HeteroData with
        * ``protein`` nodes: ``x`` (dihedral sin/cos features), ``pos``
          (CA coordinates), ``y`` (sequence labels) and optionally ``chain_M``;
        * ``ligand`` nodes: ``x`` (6-class element one-hot) and ``pos``;
        * ``('protein', 'interacts_with', 'protein')``,
          ``('ligand', 'binds', 'protein')`` and
          ``('protein', 'binds', 'ligand')`` edges, each with a one-column
          ``edge_attr`` holding the Euclidean distance.
        When there are no ligand atoms or no contacts, the ligand node and
        edge tensors are empty but still present.
    """
    data = HeteroData()

    # 1. Build Protein Nodes
    X = feature_dict["X"].squeeze(0)
    if X.dim() == 3 and X.size(1) >= 4:
        ca_coords = X[:, 1, :]
    else:
        # NOTE(review): compute_dihedrals below indexes X[:, 0..2, :], so this
        # fallback looks unreachable in practice — confirm.
        ca_coords = X
    sequence_labels = feature_dict["S"].squeeze(0)
    mask = feature_dict["mask"].squeeze(0).bool()

    # Dihedrals are computed on the full chain first, so neighbouring residues
    # are taken before masked entries are dropped.
    dihedral_features = compute_dihedrals(X)

    data['protein'].x = dihedral_features[mask].clone().float()
    data['protein'].pos = ca_coords[mask].clone().float()
    data['protein'].y = sequence_labels[mask].long()
    if "chain_M" in feature_dict:
        data['protein'].chain_M = feature_dict["chain_M"].squeeze(0)[mask]

    p_pos = data['protein'].pos
    # Placeholders are created on the same device as the protein tensors so
    # the graph never mixes devices.
    node_device = p_pos.device

    pp_edge_index = radius_graph(p_pos, r=radius_cutoff, loop=False)
    p_row, p_col = pp_edge_index
    pp_dist = torch.norm(p_pos[p_row] - p_pos[p_col], dim=1, p=2).unsqueeze(-1)
    data['protein', 'interacts_with', 'protein'].edge_index = pp_edge_index
    data['protein', 'interacts_with', 'protein'].edge_attr = pp_dist

    # 2. Build Ligand Nodes
    Y = feature_dict.get("ligand_Y")
    Y_t = feature_dict.get("ligand_Y_t")
    Y_m = feature_dict.get("ligand_Y_m")

    has_ligand = False
    if Y is not None and Y_m is not None:
        Y_mask = Y_m.bool()
        if Y_mask.sum() > 0:
            data['ligand'].x = encode_ligand_elements(Y_t[Y_mask])
            data['ligand'].pos = Y[Y_mask].float()
            has_ligand = True

    if not has_ligand:
        # No ligand atoms: keep the ligand node/edge types with zero entries.
        data['ligand'].x = torch.empty((0, 6), dtype=torch.float32, device=node_device)
        data['ligand'].pos = torch.empty((0, 3), dtype=torch.float32, device=node_device)
        no_index = torch.empty((0,), dtype=torch.long, device=node_device)
        no_dist = torch.empty((0, 1), dtype=torch.float32, device=node_device)
        _store_binding_edges(data, no_index, no_index, no_dist)
        return data

    # 3. Build Protein-Ligand Edges
    l_pos = data['ligand'].pos
    # radius(x, y, r) pairs each point of y (protein) with the points of x
    # (ligand) within r; row 0 indexes the protein, row 1 the ligand.
    pl_edge_index = radius(l_pos, p_pos, r=radius_cutoff)
    p_idx, l_idx = pl_edge_index[0], pl_edge_index[1]
    # With zero contacts the indices are empty and this yields a (0, 1) tensor.
    lp_dist = torch.norm(l_pos[l_idx] - p_pos[p_idx], dim=1, p=2).unsqueeze(-1)
    _store_binding_edges(data, l_idx, p_idx, lp_dist)
    return data
def pdb_to_pyg_data(pdb_path, radius=8.0, device="cpu"):
    """Build the heterogeneous graph for a PDB file.

    Args:
        pdb_path: Path of the PDB file to featurize.
        radius: Distance cutoff forwarded as ``radius_cutoff``.
        device: Device forwarded to the feature extraction.

    Returns:
        The graph returned by ``dict_to_pyg_data``.
    """
    return dict_to_pyg_data(
        get_ligandmpnn_features(pdb_path, device=device),
        radius_cutoff=radius,
    )
| # --------------------------------------------------------- | |
| # 3. INFERENCE ENDPOINT | |
| # --------------------------------------------------------- | |
def predict_sequence(pdb_file):
    """Predict an amino-acid sequence for an uploaded PDB structure.

    Args:
        pdb_file: The value delivered by the Gradio file input. May be
            ``None`` (nothing uploaded), a path string / path-like object, or
            a file wrapper exposing the path as ``.name``.

    Returns:
        A human-readable string: the predicted sequence with its residue
        count, a prompt to upload a file, or an error message. Exceptions are
        reported in the returned text instead of being raised, because this
        function is the UI boundary.
    """
    if pdb_file is None:
        return "Please upload a .pdb file."

    try:
        # Depending on the Gradio version/configuration the upload arrives as
        # a plain path or as a temp-file wrapper; accept both instead of
        # assuming ``.name`` exists.
        if isinstance(pdb_file, (str, os.PathLike)):
            pdb_path = os.fspath(pdb_file)
        else:
            pdb_path = pdb_file.name

        # Build the Heterogeneous Graph
        data = pdb_to_pyg_data(pdb_path, device=device)
        data = data.to(device)
        num_residues = data['protein'].x.shape[0]

        # Run the forward pass
        with torch.no_grad():
            logits = model(data)

        # Decode logits to an amino acid string
        predicted_indices = torch.argmax(logits, dim=-1).cpu().numpy()
        predicted_seq = "".join(AA_ALPHABET[idx] for idx in predicted_indices)
        return f"Predicted Sequence ({num_residues} residues):\n\n{predicted_seq}"
    except Exception as e:
        return f"Error processing PDB: {str(e)}"
| # --------------------------------------------------------- | |
| # 4. GRADIO UI | |
| # --------------------------------------------------------- | |
# Input widget: a single uploaded file, restricted to the .pdb extension.
_pdb_upload = gr.File(
    label="Upload Target Protein Backbone (.pdb)",
    file_types=[".pdb"],
)

# Output widget: multi-line text box showing the result string.
_sequence_output = gr.Textbox(label="Designed Amino Acid Sequence", lines=5)

_app_description = (
    "Upload a 3D target backbone to generate a sequence optimized by a custom Heterogeneous Graph Neural Network.\n\n"
    "**Model Performance:** Achieves ~30.3% global sequence recovery and **35.1% binding-pocket recovery** "
    "on noisy coordinates, confirming strong generalization to underlying biophysical folding constraints."
)

# Wire the prediction function to the widgets.
demo = gr.Interface(
    fn=predict_sequence,
    inputs=_pdb_upload,
    outputs=_sequence_output,
    title="Struct2Seq-GNN: Inverse Protein Folding",
    description=_app_description,
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()