WSobo commited on
Commit
cd525f1
·
verified ·
1 Parent(s): dbc4151

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -42
app.py CHANGED
@@ -1,11 +1,13 @@
1
  import os
2
  import urllib.request
 
3
  import gradio as gr
4
  import torch
5
- import prody
6
  import numpy as np
 
 
7
 
8
- # Import your model class (Make sure model_utils.py is uploaded to your Space!)
9
  from model_utils import Struct2SeqGNN
10
 
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -16,7 +18,6 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
  raw_github_url = "https://raw.githubusercontent.com/WSobo/Struct2Seq-GNN/main/pretrained_models/v2.0/best_model.pt"
17
  model_path = "best_model.pt"
18
 
19
- # Download weights if they aren't already cached in the Space
20
  if not os.path.exists(model_path):
21
  print("Downloading model weights from GitHub...")
22
  urllib.request.urlretrieve(raw_github_url, model_path)
@@ -38,56 +39,211 @@ if list(state_dict.keys())[0].startswith('module.'):
38
  model.load_state_dict(state_dict)
39
  model.eval()
40
 
41
- # Standard Amino Acid alphabet (Update this if your model uses a different index order!)
42
  AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"
43
 
44
 
45
  # ---------------------------------------------------------
46
- # 2. INFERENCE PIPELINE
47
  # ---------------------------------------------------------
48
- def predict_sequence(pdb_file):
49
- if pdb_file is None:
50
- return "Please upload a .pdb file."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- try:
53
- # Step 2.1: Silence ProDy logs to keep your server console clean
54
- prody.confProDy(verbosity='none')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- # Step 2.2: Parse the uploaded PDB
57
- pdb = prody.parsePDB(pdb_file.name)
58
- if pdb is None:
59
- return "Error: Could not parse the PDB file."
60
-
61
- # Step 2.3: Extract backbone coordinates
62
- # (Grabbing C-alphas to get the sequence length and main coordinates)
63
- calphas = pdb.select('calpha')
64
- if calphas is None:
65
- return "Error: No alpha carbons found in the PDB."
66
 
67
- num_residues = len(calphas)
 
68
 
69
- # Step 2.4: Convert coordinates to numpy/tensors
70
- # coords = calphas.getCoords() # Shape: [num_residues, 3]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # =====================================================================
73
- # Step 2.5: YOUR GRAPH CONSTRUCTION GOES HERE
74
- # Copy the exact logic you use in `utils.dataset.Struct2SeqDataset`
75
- # to turn these coordinates into your graph components (x, edge_index).
76
- # =====================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- # NOTE: Delete this placeholder dummy block once your logic is in!
79
- dummy_logits = torch.randn((num_residues, 21)).to(device)
80
 
81
- # Step 2.6: Run the forward pass
82
  with torch.no_grad():
83
- # logits = model(x, edge_index, ...)
84
- logits = dummy_logits # Placeholder
85
 
86
- # Step 2.7: Decode logits to an amino acid string
87
- # Argmax gets the index of the highest probability AA for each residue
88
  predicted_indices = torch.argmax(logits, dim=-1).cpu().numpy()
89
-
90
- # Map indices back to the alphabet characters
91
  predicted_seq = "".join([AA_ALPHABET[idx] for idx in predicted_indices])
92
 
93
  return f"Predicted Sequence ({num_residues} residues):\n\n{predicted_seq}"
@@ -95,9 +251,8 @@ def predict_sequence(pdb_file):
95
  except Exception as e:
96
  return f"Error processing PDB: {str(e)}"
97
 
98
-
99
  # ---------------------------------------------------------
100
- # 3. GRADIO UI
101
  # ---------------------------------------------------------
102
  demo = gr.Interface(
103
  fn=predict_sequence,
@@ -105,12 +260,12 @@ demo = gr.Interface(
105
  outputs=gr.Textbox(label="Designed Amino Acid Sequence", show_copy_button=True, lines=5),
106
  title="Struct2Seq-GNN: Inverse Protein Folding",
107
  description=(
108
- "Upload a 3D target backbone to generate a sequence optimized by a custom Graph Neural Network.\n\n"
109
  "**Model Performance:** Achieves ~30.3% global sequence recovery and **35.1% binding-pocket recovery** "
110
  "on noisy coordinates, confirming strong generalization to underlying biophysical folding constraints."
111
  ),
112
  allow_flagging="never",
113
- theme=gr.themes.Soft() # Adds a cleaner, more professional look
114
  )
115
 
116
  if __name__ == "__main__":
 
1
  import os
2
  import urllib.request
3
+ import importlib.util
4
  import gradio as gr
5
  import torch
 
6
  import numpy as np
7
+ from torch_geometric.data import HeteroData
8
+ from torch_geometric.nn import radius_graph, radius
9
 
10
+ # Import your model class (Make sure model_utils.py is in your Space!)
11
  from model_utils import Struct2SeqGNN
12
 
13
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
18
  raw_github_url = "https://raw.githubusercontent.com/WSobo/Struct2Seq-GNN/main/pretrained_models/v2.0/best_model.pt"
19
  model_path = "best_model.pt"
20
 
 
21
  if not os.path.exists(model_path):
22
  print("Downloading model weights from GitHub...")
23
  urllib.request.urlretrieve(raw_github_url, model_path)
 
39
  model.load_state_dict(state_dict)
40
  model.eval()
41
 
42
+ # Standard Amino Acid alphabet
43
  AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"
44
 
45
 
46
  # ---------------------------------------------------------
47
+ # 2. DATA PROCESSING PIPELINE (PyG HeteroData)
48
  # ---------------------------------------------------------
49
+ def _load_ligandmpnn_parsers():
50
+ """Load LigandMPNN parser functions directly from the HF Space root."""
51
+ parser_file = "data_utils.py"
52
+ if not os.path.exists(parser_file):
53
+ raise ImportError(
54
+ "Could not find data_utils.py. "
55
+ "Please upload the LigandMPNN data_utils.py file to the root of your Hugging Face Space."
56
+ )
57
+
58
+ spec = importlib.util.spec_from_file_location("ligandmpnn_data_utils", parser_file)
59
+ module = importlib.util.module_from_spec(spec)
60
+ assert spec.loader is not None
61
+ spec.loader.exec_module(module)
62
+ return module.parse_PDB, module.featurize
63
+
64
+ parse_PDB, featurize = _load_ligandmpnn_parsers()
65
+
66
+ def get_ligandmpnn_features(pdb_path, device="cpu"):
67
+ protein_dict, backbone, other_atoms, icodes, _ = parse_PDB(pdb_path, device=device)
68
 
69
+ if "chain_letters" in protein_dict:
70
+ protein_dict["chain_mask"] = torch.ones(
71
+ len(protein_dict["chain_letters"]),
72
+ dtype=torch.int32,
73
+ device=device
74
+ )
75
+
76
+ feature_dict = featurize(protein_dict, cutoff_for_score=8.0)
77
+
78
+ feature_dict["ligand_Y"] = protein_dict.get("Y", None)
79
+ feature_dict["ligand_Y_t"] = protein_dict.get("Y_t", None)
80
+ feature_dict["ligand_Y_m"] = protein_dict.get("Y_m", None)
81
+
82
+ return feature_dict
83
+
84
+ def compute_dihedrals(X):
85
+ N = X[:, 0, :]
86
+ CA = X[:, 1, :]
87
+ C = X[:, 2, :]
88
+
89
+ C_prev = torch.cat([C[0:1], C[:-1]], dim=0)
90
+ N_next = torch.cat([N[1:], N[-1:]], dim=0)
91
+ CA_next = torch.cat([CA[1:], CA[-1:]], dim=0)
92
+
93
+ def dihedral(p0, p1, p2, p3):
94
+ b0 = p0 - p1
95
+ b1 = p2 - p1
96
+ b2 = p3 - p2
97
 
98
+ b1_norm = b1 / (torch.linalg.norm(b1, dim=-1, keepdim=True) + 1e-7)
99
+
100
+ n1 = torch.linalg.cross(b0, b1_norm, dim=-1)
101
+ n2 = torch.linalg.cross(b1_norm, b2, dim=-1)
102
+ m = torch.linalg.cross(n1, b1_norm, dim=-1)
 
 
 
 
 
103
 
104
+ x = torch.sum(n1 * n2, dim=-1)
105
+ y = torch.sum(m * n2, dim=-1)
106
 
107
+ return torch.atan2(y, x)
108
+
109
+ phi = dihedral(C_prev, N, CA, C)
110
+ psi = dihedral(N, CA, C, N_next)
111
+ omega = dihedral(CA, C, N_next, CA_next)
112
+
113
+ dihedrals = torch.stack([phi, psi, omega], dim=-1)
114
+ return torch.cat([torch.sin(dihedrals), torch.cos(dihedrals)], dim=-1)
115
+
116
+ def encode_ligand_elements(element_ids):
117
+ M = element_ids.shape[0]
118
+ one_hot = torch.zeros((M, 6), dtype=torch.float32, device=element_ids.device)
119
+
120
+ mask_C = (element_ids == 6)
121
+ mask_N = (element_ids == 7)
122
+ mask_O = (element_ids == 8)
123
+ mask_S = (element_ids == 16)
124
+ mask_P = (element_ids == 15)
125
+
126
+ one_hot[mask_C, 0] = 1.0
127
+ one_hot[mask_N, 1] = 1.0
128
+ one_hot[mask_O, 2] = 1.0
129
+ one_hot[mask_S, 3] = 1.0
130
+ one_hot[mask_P, 4] = 1.0
131
+
132
+ mask_other = ~(mask_C | mask_N | mask_O | mask_S | mask_P)
133
+ one_hot[mask_other, 5] = 1.0
134
+
135
+ return one_hot
136
+
137
+ def dict_to_pyg_data(feature_dict, radius_cutoff=8.0):
138
+ data = HeteroData()
139
+
140
+ # 1. Build Protein Nodes
141
+ X = feature_dict["X"].squeeze(0)
142
+ if X.dim() == 3 and X.size(1) >= 4:
143
+ ca_coords = X[:, 1, :]
144
+ else:
145
+ ca_coords = X
146
+
147
+ sequence_labels = feature_dict["S"].squeeze(0)
148
+ mask = feature_dict["mask"].squeeze(0).bool()
149
+
150
+ dihedral_features = compute_dihedrals(X)
151
+
152
+ ca_coords = ca_coords[mask]
153
+ sequence_labels = sequence_labels[mask]
154
+ dihedral_features = dihedral_features[mask]
155
+
156
+ data['protein'].x = dihedral_features.clone().float()
157
+ data['protein'].pos = ca_coords.clone().float()
158
+ data['protein'].y = sequence_labels.long()
159
+
160
+ if "chain_M" in feature_dict:
161
+ data['protein'].chain_M = feature_dict["chain_M"].squeeze(0)[mask]
162
+
163
+ p_pos = data['protein'].pos
164
+ pp_edge_index = radius_graph(p_pos, r=radius_cutoff, loop=False)
165
+ p_row, p_col = pp_edge_index
166
+ pp_dist = torch.norm(p_pos[p_row] - p_pos[p_col], dim=1, p=2).unsqueeze(-1)
167
+
168
+ data['protein', 'interacts_with', 'protein'].edge_index = pp_edge_index
169
+ data['protein', 'interacts_with', 'protein'].edge_attr = pp_dist
170
+
171
+ # 2. Build Ligand Nodes
172
+ Y = feature_dict.get("ligand_Y")
173
+ Y_t = feature_dict.get("ligand_Y_t")
174
+ Y_m = feature_dict.get("ligand_Y_m")
175
+
176
+ num_ligand_atoms = 0
177
+ if Y is not None and Y_m is not None:
178
+ Y_mask = Y_m.bool()
179
+ if Y_mask.sum() > 0:
180
+ Y = Y[Y_mask]
181
+ Y_t = Y_t[Y_mask]
182
+ num_ligand_atoms = Y.shape[0]
183
+
184
+ lig_x = encode_ligand_elements(Y_t)
185
+ data['ligand'].x = lig_x
186
+ data['ligand'].pos = Y.float()
187
+
188
+ if num_ligand_atoms > 0:
189
+ l_pos = data['ligand'].pos
190
+ pl_edge_index = radius(l_pos, p_pos, r=radius_cutoff)
191
 
192
+ if pl_edge_index.size(1) > 0:
193
+ p_idx, l_idx = pl_edge_index[0], pl_edge_index[1]
194
+
195
+ lp_edge_index = torch.stack([l_idx, p_idx], dim=0)
196
+ lp_dist = torch.norm(l_pos[l_idx] - p_pos[p_idx], dim=1, p=2).unsqueeze(-1)
197
+
198
+ data['ligand', 'binds', 'protein'].edge_index = lp_edge_index
199
+ data['ligand', 'binds', 'protein'].edge_attr = lp_dist
200
+
201
+ pl_edge_index_rev = torch.stack([p_idx, l_idx], dim=0)
202
+ data['protein', 'binds', 'ligand'].edge_index = pl_edge_index_rev
203
+ data['protein', 'binds', 'ligand'].edge_attr = lp_dist.clone()
204
+ else:
205
+ data['ligand', 'binds', 'protein'].edge_index = torch.empty((2, 0), dtype=torch.long)
206
+ data['ligand', 'binds', 'protein'].edge_attr = torch.empty((0, 1), dtype=torch.float32)
207
+ data['protein', 'binds', 'ligand'].edge_index = torch.empty((2, 0), dtype=torch.long)
208
+ data['protein', 'binds', 'ligand'].edge_attr = torch.empty((0, 1), dtype=torch.float32)
209
+
210
+ else:
211
+ data['ligand'].x = torch.empty((0, 6), dtype=torch.float32)
212
+ data['ligand'].pos = torch.empty((0, 3), dtype=torch.float32)
213
+ data['ligand', 'binds', 'protein'].edge_index = torch.empty((2, 0), dtype=torch.long)
214
+ data['ligand', 'binds', 'protein'].edge_attr = torch.empty((0, 1), dtype=torch.float32)
215
+ data['protein', 'binds', 'ligand'].edge_index = torch.empty((2, 0), dtype=torch.long)
216
+ data['protein', 'binds', 'ligand'].edge_attr = torch.empty((0, 1), dtype=torch.float32)
217
+
218
+ return data
219
+
220
+ def pdb_to_pyg_data(pdb_path, radius=8.0, device="cpu"):
221
+ feature_dict = get_ligandmpnn_features(pdb_path, device=device)
222
+ data = dict_to_pyg_data(feature_dict, radius_cutoff=radius)
223
+ return data
224
+
225
+
226
+ # ---------------------------------------------------------
227
+ # 3. INFERENCE ENDPOINT
228
+ # ---------------------------------------------------------
229
+ def predict_sequence(pdb_file):
230
+ if pdb_file is None:
231
+ return "Please upload a .pdb file."
232
+
233
+ try:
234
+ # Build the Heterogeneous Graph
235
+ data = pdb_to_pyg_data(pdb_file.name, device=device)
236
+ data = data.to(device)
237
 
238
+ num_residues = data['protein'].x.shape[0]
 
239
 
240
+ # Run the forward pass
241
  with torch.no_grad():
242
+ # Adjust the arguments here if your forward method signature differs!
243
+ logits = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict)
244
 
245
+ # Decode logits to an amino acid string
 
246
  predicted_indices = torch.argmax(logits, dim=-1).cpu().numpy()
 
 
247
  predicted_seq = "".join([AA_ALPHABET[idx] for idx in predicted_indices])
248
 
249
  return f"Predicted Sequence ({num_residues} residues):\n\n{predicted_seq}"
 
251
  except Exception as e:
252
  return f"Error processing PDB: {str(e)}"
253
 
 
254
  # ---------------------------------------------------------
255
+ # 4. GRADIO UI
256
  # ---------------------------------------------------------
257
  demo = gr.Interface(
258
  fn=predict_sequence,
 
260
  outputs=gr.Textbox(label="Designed Amino Acid Sequence", show_copy_button=True, lines=5),
261
  title="Struct2Seq-GNN: Inverse Protein Folding",
262
  description=(
263
+ "Upload a 3D target backbone to generate a sequence optimized by a custom Heterogeneous Graph Neural Network.\n\n"
264
  "**Model Performance:** Achieves ~30.3% global sequence recovery and **35.1% binding-pocket recovery** "
265
  "on noisy coordinates, confirming strong generalization to underlying biophysical folding constraints."
266
  ),
267
  allow_flagging="never",
268
+ theme=gr.themes.Soft()
269
  )
270
 
271
  if __name__ == "__main__":