WSobo committed on
Commit
f678ca5
·
verified ·
1 Parent(s): d0100b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -25
app.py CHANGED
@@ -1,41 +1,116 @@
 
 
1
  import gradio as gr
2
  import torch
3
- from huggingface_hub import hf_hub_download
4
- # import prody # (Uncomment when you add your parsing logic)
5
 
6
- # 1. Download the weights from your Model repository
7
- # Make sure "WSobo/Struct2Seq-GNN" matches exactly what you named your model repo
8
- model_path = hf_hub_download(repo_id="WSobo/Struct2Seq-GNN", filename="best_model.pt")
9
 
10
- # 2. Load the model (Ensure your model class is defined in this file or imported)
11
- # model = MyGNNModel()
12
- # model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
13
- # model.eval()
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def predict_sequence(pdb_file):
16
- """
17
- This function is triggered every time a user hits 'Submit'.
18
- """
19
  if pdb_file is None:
20
  return "Please upload a .pdb file."
21
 
22
- # -- YOUR INFERENCE PIPELINE GOES HERE --
23
- # 1. Parse the uploaded pdb_file.name using ProDy
24
- # 2. Convert coordinates to tensors
25
- # 3. Run the forward pass: logits = model(coords)
26
- # 4. Decode logits to an amino acid string
27
-
28
- # Dummy return for testing the UI
29
- return "MKTLLILAVIM... (Your Struct2Seq-GNN prediction will appear here!)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- # 3. Build the Gradio UI
 
 
32
  demo = gr.Interface(
33
  fn=predict_sequence,
34
- inputs=gr.File(label="Upload Protein Backbone (.pdb)", file_types=[".pdb"]),
35
- outputs=gr.Textbox(label="Predicted Amino Acid Sequence", show_copy_button=True),
36
  title="Struct2Seq-GNN: Inverse Protein Folding",
37
- description="Upload a 3D target backbone to generate a sequence optimized by a custom Graph Neural Network.",
38
- allow_flagging="never"
 
 
 
 
 
39
  )
40
 
41
  if __name__ == "__main__":
 
1
+ import os
2
+ import urllib.request
3
  import gradio as gr
4
  import torch
5
+ import prody
6
+ import numpy as np
7
 
8
+ # Import your model class (Make sure model_utils.py is uploaded to your Space!)
9
+ from model_utils import Struct2SeqGNN
 
10
 
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
12
 
13
+ # ---------------------------------------------------------
14
+ # 1. DOWNLOAD & LOAD MODEL WEIGHTS
15
+ # ---------------------------------------------------------
16
+ raw_github_url = "https://raw.githubusercontent.com/WSobo/Struct2Seq-GNN/main/pretrained_models/v2.0/best_model.pt"
17
+ model_path = "best_model.pt"
18
+
19
+ # Download weights if they aren't already cached in the Space
20
+ if not os.path.exists(model_path):
21
+ print("Downloading model weights from GitHub...")
22
+ urllib.request.urlretrieve(raw_github_url, model_path)
23
+
24
+ # Instantiate the model matching your v2.0 training parameters
25
+ model = Struct2SeqGNN(
26
+ node_features=6,
27
+ ligand_features=6,
28
+ hidden_dim=256,
29
+ num_classes=21,
30
+ num_layers=6,
31
+ dropout=0.0
32
+ ).to(device)
33
+
34
+ # Load the weights with DDP prefix handling
35
+ state_dict = torch.load(model_path, map_location=device)
36
+ if list(state_dict.keys())[0].startswith('module.'):
37
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
38
+ model.load_state_dict(state_dict)
39
+ model.eval()
40
+
41
+ # Standard Amino Acid alphabet (Update this if your model uses a different index order!)
42
+ AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"
43
+
44
+
45
+ # ---------------------------------------------------------
46
+ # 2. INFERENCE PIPELINE
47
+ # ---------------------------------------------------------
48
  def predict_sequence(pdb_file):
 
 
 
49
  if pdb_file is None:
50
  return "Please upload a .pdb file."
51
 
52
+ try:
53
+ # Step 2.1: Silence ProDy logs to keep your server console clean
54
+ prody.confProDy(verbosity='none')
55
+
56
+ # Step 2.2: Parse the uploaded PDB
57
+ pdb = prody.parsePDB(pdb_file.name)
58
+ if pdb is None:
59
+ return "Error: Could not parse the PDB file."
60
+
61
+ # Step 2.3: Extract backbone coordinates
62
+ # (Grabbing C-alphas to get the sequence length and main coordinates)
63
+ calphas = pdb.select('calpha')
64
+ if calphas is None:
65
+ return "Error: No alpha carbons found in the PDB."
66
+
67
+ num_residues = len(calphas)
68
+
69
+ # Step 2.4: Convert coordinates to numpy/tensors
70
+ # coords = calphas.getCoords() # Shape: [num_residues, 3]
71
+
72
+ # =====================================================================
73
+ # Step 2.5: YOUR GRAPH CONSTRUCTION GOES HERE
74
+ # Copy the exact logic you use in `utils.dataset.Struct2SeqDataset`
75
+ # to turn these coordinates into your graph components (x, edge_index).
76
+ # =====================================================================
77
+
78
+ # NOTE: Delete this placeholder dummy block once your logic is in!
79
+ dummy_logits = torch.randn((num_residues, 21)).to(device)
80
+
81
+ # Step 2.6: Run the forward pass
82
+ with torch.no_grad():
83
+ # logits = model(x, edge_index, ...)
84
+ logits = dummy_logits # Placeholder
85
+
86
+ # Step 2.7: Decode logits to an amino acid string
87
+ # Argmax gets the index of the highest probability AA for each residue
88
+ predicted_indices = torch.argmax(logits, dim=-1).cpu().numpy()
89
+
90
+ # Map indices back to the alphabet characters
91
+ predicted_seq = "".join([AA_ALPHABET[idx] for idx in predicted_indices])
92
+
93
+ return f"Predicted Sequence ({num_residues} residues):\n\n{predicted_seq}"
94
+
95
+ except Exception as e:
96
+ return f"Error processing PDB: {str(e)}"
97
+
98
 
99
+ # ---------------------------------------------------------
100
+ # 3. GRADIO UI
101
+ # ---------------------------------------------------------
102
  demo = gr.Interface(
103
  fn=predict_sequence,
104
+ inputs=gr.File(label="Upload Target Protein Backbone (.pdb)", file_types=[".pdb"]),
105
+ outputs=gr.Textbox(label="Designed Amino Acid Sequence", show_copy_button=True, lines=5),
106
  title="Struct2Seq-GNN: Inverse Protein Folding",
107
+ description=(
108
+ "Upload a 3D target backbone to generate a sequence optimized by a custom Graph Neural Network.\n\n"
109
+ "**Model Performance:** Achieves ~30.3% global sequence recovery and **35.1% binding-pocket recovery** "
110
+ "on noisy coordinates, confirming strong generalization to underlying biophysical folding constraints."
111
+ ),
112
+ allow_flagging="never",
113
+ theme=gr.themes.Soft() # Adds a cleaner, more professional look
114
  )
115
 
116
  if __name__ == "__main__":