File size: 9,889 Bytes
f678ca5
 
cd525f1
d0100b9
 
f678ca5
cd525f1
 
d0100b9
cd525f1
f678ca5
d0100b9
f678ca5
d0100b9
f678ca5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd525f1
f678ca5
 
 
 
cd525f1
f678ca5
cd525f1
 
489d5d8
cd525f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0100b9
cd525f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f678ca5
cd525f1
 
 
 
 
f678ca5
cd525f1
 
f678ca5
cd525f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f678ca5
cd525f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f678ca5
cd525f1
f678ca5
cd525f1
f678ca5
060ccef
f678ca5
cd525f1
f678ca5
 
 
 
 
 
 
 
 
cd525f1
f678ca5
d0100b9
 
f678ca5
85fee63
d0100b9
f678ca5
cd525f1
f678ca5
 
85fee63
d0100b9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import importlib.util
import os
import traceback
import urllib.request

import gradio as gr
import numpy as np
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import radius_graph, radius

# Import your model class (Make sure model_utils.py is in your Space!)
from model_utils import Struct2SeqGNN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------
# 1. DOWNLOAD & LOAD MODEL WEIGHTS
# ---------------------------------------------------------
raw_github_url = "https://raw.githubusercontent.com/WSobo/Struct2Seq-GNN/main/pretrained_models/v2.0/best_model.pt"
model_path = "best_model.pt"

if not os.path.exists(model_path):
    print("Downloading model weights from GitHub...")
    # Download under a temporary name and rename only after the transfer has
    # finished: an interrupted download must not leave a truncated
    # best_model.pt that the exists() check above would trust on every later start.
    urllib.request.urlretrieve(raw_github_url, model_path + ".part")
    os.replace(model_path + ".part", model_path)

# Instantiate the model matching the v2.0 training parameters.
# node_features=6 matches compute_dihedrals' sin/cos of 3 torsions;
# ligand_features=6 matches encode_ligand_elements' 6 one-hot columns;
# num_classes=21 matches len(AA_ALPHABET).
model = Struct2SeqGNN(
    node_features=6,
    ligand_features=6,
    hidden_dim=256,
    num_classes=21,
    num_layers=6,
    dropout=0.0
).to(device)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint fetched over the network cannot run code on load; a plain
# state dict (what load_state_dict below expects) needs nothing more.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
# Checkpoints saved from a DistributedDataParallel wrapper prefix every key with
# "module."; strip it (only from keys that carry it) so keys match the bare model.
# The emptiness check avoids indexing into an empty checkpoint.
if state_dict and next(iter(state_dict)).startswith('module.'):
    state_dict = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in state_dict.items()
    }
model.load_state_dict(state_dict)
model.eval()  # inference only

# Amino-acid alphabet used to decode class index i -> AA_ALPHABET[i]
# (20 standard residues + 'X'). Presumably matches the index order of the
# LigandMPNN "S" labels the model was trained on — confirm if retraining.
AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"


# ---------------------------------------------------------
# 2. DATA PROCESSING PIPELINE (PyG HeteroData)
# ---------------------------------------------------------
def _load_ligandmpnn_parsers(parser_file="LigandMPNN_data_utils.py"):
    """Load LigandMPNN's ``parse_PDB`` and ``featurize`` from a standalone file.

    The file is imported by path (it is not an installed package) under the
    module name ``ligandmpnn_data_utils``.

    Args:
        parser_file: Path to a copy of LigandMPNN's ``data_utils.py``. Relative
            paths are resolved against the current working directory (the Space
            root when run on Hugging Face).

    Returns:
        Tuple ``(parse_PDB, featurize)`` taken from the loaded module.

    Raises:
        ImportError: If the file is missing, cannot be loaded as a module, or
            does not define both ``parse_PDB`` and ``featurize``.
    """
    if not os.path.exists(parser_file):
        raise ImportError(
            f"Could not find {parser_file}. "
            "Please upload the LigandMPNN data_utils.py file to the root of your "
            f"Hugging Face Space, saved as {parser_file}."
        )

    spec = importlib.util.spec_from_file_location("ligandmpnn_data_utils", parser_file)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create an import loader for {parser_file}.")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    try:
        return module.parse_PDB, module.featurize
    except AttributeError as err:
        raise ImportError(
            f"{parser_file} does not define both parse_PDB and featurize."
        ) from err

# Resolve LigandMPNN's parse_PDB / featurize once, at import time. If the parser
# file is missing this raises ImportError, so the app fails at startup rather
# than on the first request.
parse_PDB, featurize = _load_ligandmpnn_parsers()

def get_ligandmpnn_features(pdb_path, device="cpu"):
    """Parse ``pdb_path`` with LigandMPNN and return its feature dict.

    The dict is the output of LigandMPNN's ``featurize`` (called with
    ``cutoff_for_score=8.0``), extended with the raw ligand arrays taken from
    the parsed protein dict:

    * ``ligand_Y``   <- ``protein_dict["Y"]``
    * ``ligand_Y_t`` <- ``protein_dict["Y_t"]``
    * ``ligand_Y_m`` <- ``protein_dict["Y_m"]``

    Each is ``None`` when the parser did not produce that key.
    """
    protein_dict, _backbone, _other_atoms, _icodes, _unused = parse_PDB(pdb_path, device=device)

    if "chain_letters" in protein_dict:
        # All-ones int32 chain_mask with one entry per chain_letters element
        # (presumably "every residue designable", as featurize expects — confirm).
        num_positions = len(protein_dict["chain_letters"])
        protein_dict["chain_mask"] = torch.ones(num_positions, dtype=torch.int32, device=device)

    features = featurize(protein_dict, cutoff_for_score=8.0)

    for suffix in ("Y", "Y_t", "Y_m"):
        features["ligand_" + suffix] = protein_dict.get(suffix, None)

    return features

def compute_dihedrals(X):
    """Return sin/cos backbone torsion features (phi, psi, omega) per residue.

    Args:
        X: Backbone coordinates indexed as ``X[:, atom, :]``; atoms 0, 1, 2 are
            used as N, CA, C. Assumes shape (L, n_atoms >= 3, 3) — the cross
            products below need a last dimension of size 3. TODO confirm the
            atom order matches the LigandMPNN parser output.

    Returns:
        Tensor of shape (L, 6):
        ``[sin phi, sin psi, sin omega, cos phi, cos psi, cos omega]``.

    Notes:
        * Neighbours are taken by index (i-1 / i+1), not by chain topology, so
          torsions are also computed across chain breaks and between chains.
        * At the ends the shifted tensors repeat the edge residue's own atoms,
          so phi of residue 0 and psi/omega of residue L-1 are not real
          torsions (their first and last points coincide).
        * NOTE(review): with this cross-product arrangement the returned angle
          is (pi - phi) relative to the IUPAC convention (a cis arrangement
          gives +/-pi, trans gives 0): sin terms agree with IUPAC, cos terms
          are negated. Do not "fix" this here alone — the pretrained weights
          were presumably trained on features from this same formula.
    """
    N = X[:, 0, :]
    CA = X[:, 1, :]
    C = X[:, 2, :]
    
    # Previous residue's C and next residue's N / CA; the ends are padded by
    # repeating the first/last residue's atom.
    C_prev = torch.cat([C[0:1], C[:-1]], dim=0)
    N_next = torch.cat([N[1:], N[-1:]], dim=0)
    CA_next = torch.cat([CA[1:], CA[-1:]], dim=0)
    
    def dihedral(p0, p1, p2, p3):
        """Signed torsion angle (radians, in [-pi, pi]) for batched points p0..p3."""
        # Bond vectors: p1->p0 (points backwards), p1->p2, p2->p3.
        b0 = p0 - p1
        b1 = p2 - p1
        b2 = p3 - p2
        
        # Unit central bond; 1e-7 avoids division by zero for a zero-length bond.
        b1_norm = b1 / (torch.linalg.norm(b1, dim=-1, keepdim=True) + 1e-7)
        
        # n1 / n2: normals of the (p0,p1,p2) and (p1,p2,p3) planes;
        # m completes an orthogonal frame with n1 and the central bond.
        n1 = torch.linalg.cross(b0, b1_norm, dim=-1)
        n2 = torch.linalg.cross(b1_norm, b2, dim=-1)
        m = torch.linalg.cross(n1, b1_norm, dim=-1)
        
        # Projections of n2 onto that frame; atan2 turns them into a signed angle.
        x = torch.sum(n1 * n2, dim=-1)
        y = torch.sum(m * n2, dim=-1)
        
        return torch.atan2(y, x)
        
    phi = dihedral(C_prev, N, CA, C)
    psi = dihedral(N, CA, C, N_next)
    omega = dihedral(CA, C, N_next, CA_next)
    
    # (L, 3) angles -> (L, 6) [sin..., cos...] so the encoding is continuous across +/-pi.
    dihedrals = torch.stack([phi, psi, omega], dim=-1)
    return torch.cat([torch.sin(dihedrals), torch.cos(dihedrals)], dim=-1)

def encode_ligand_elements(element_ids):
    """One-hot encode ligand atom element ids into 6 float32 columns.

    Columns: 0 = C (id 6), 1 = N (id 7), 2 = O (id 8), 3 = S (id 16),
    4 = P (id 15), 5 = any other id. The ids are presumably atomic numbers as
    produced by the LigandMPNN parser — confirm against the parser.

    Args:
        element_ids: 1-D tensor of element ids, one per ligand atom.

    Returns:
        Float32 tensor of shape (M, 6) on ``element_ids.device``.
    """
    # (atomic number, output column) for every explicitly encoded element;
    # note S precedes P in the column order.
    known_columns = ((6, 0), (7, 1), (8, 2), (16, 3), (15, 4))

    encoded = torch.zeros(
        (element_ids.shape[0], 6), dtype=torch.float32, device=element_ids.device
    )
    matched_any = torch.zeros_like(element_ids, dtype=torch.bool)

    for atomic_number, column in known_columns:
        hits = element_ids == atomic_number
        encoded[hits, column] = 1.0
        matched_any |= hits

    # Everything not matched above falls into the catch-all "other" column.
    encoded[~matched_any, 5] = 1.0
    return encoded

def _set_empty_ligand_edges(data, device):
    """Give both ligand<->protein edge types zero edges so the graph schema is always complete."""
    for edge_type in (('ligand', 'binds', 'protein'), ('protein', 'binds', 'ligand')):
        data[edge_type].edge_index = torch.empty((2, 0), dtype=torch.long, device=device)
        data[edge_type].edge_attr = torch.empty((0, 1), dtype=torch.float32, device=device)


def _add_protein_graph(data, feature_dict, radius_cutoff):
    """Add masked protein nodes and protein-protein radius edges to ``data``; return CA positions."""
    X = feature_dict["X"].squeeze(0)
    # compute_dihedrals indexes backbone atoms 0..2 of a rank-3 tensor, so any
    # other layout cannot be featurized; fail with a clear message instead of
    # silently treating the whole tensor as CA coordinates.
    if X.dim() != 3 or X.size(1) < 3:
        raise ValueError(
            "Expected backbone coordinates of shape (L, n_atoms>=3, 3), "
            f"got {tuple(X.shape)}"
        )
    mask = feature_dict["mask"].squeeze(0).bool()

    # Torsions use index neighbours in the unmasked tensor, then are filtered.
    dihedral_features = compute_dihedrals(X)[mask]
    ca_coords = X[:, 1, :][mask]
    sequence_labels = feature_dict["S"].squeeze(0)[mask]

    data['protein'].x = dihedral_features.clone().float()
    data['protein'].pos = ca_coords.clone().float()
    data['protein'].y = sequence_labels.long()
    if "chain_M" in feature_dict:
        data['protein'].chain_M = feature_dict["chain_M"].squeeze(0)[mask]

    p_pos = data['protein'].pos
    # NOTE(review): no max_num_neighbors is passed, so PyG's default cap applies — confirm it matches training.
    pp_edge_index = radius_graph(p_pos, r=radius_cutoff, loop=False)
    p_row, p_col = pp_edge_index
    pp_dist = torch.norm(p_pos[p_row] - p_pos[p_col], dim=1, p=2).unsqueeze(-1)

    data['protein', 'interacts_with', 'protein'].edge_index = pp_edge_index
    data['protein', 'interacts_with', 'protein'].edge_attr = pp_dist
    return p_pos


def _add_ligand_graph(data, feature_dict, p_pos, radius_cutoff):
    """Add masked ligand nodes and ligand<->protein radius edges (empty stores when there are none)."""
    Y = feature_dict.get("ligand_Y")
    Y_t = feature_dict.get("ligand_Y_t")
    Y_m = feature_dict.get("ligand_Y_m")

    num_ligand_atoms = 0
    if Y is not None and Y_m is not None:
        Y_mask = Y_m.bool()
        if Y_mask.sum() > 0:
            Y = Y[Y_mask]
            Y_t = Y_t[Y_mask]
            num_ligand_atoms = Y.shape[0]
            data['ligand'].x = encode_ligand_elements(Y_t)
            data['ligand'].pos = Y.float()

    if num_ligand_atoms == 0:
        # Keep a (0-node) ligand store so downstream code always sees the same node/edge types.
        data['ligand'].x = torch.empty((0, 6), dtype=torch.float32, device=p_pos.device)
        data['ligand'].pos = torch.empty((0, 3), dtype=torch.float32, device=p_pos.device)
        _set_empty_ligand_edges(data, p_pos.device)
        return

    l_pos = data['ligand'].pos
    # radius(x, y, r) yields pairs (index into y, index into x):
    # row 0 indexes protein nodes, row 1 indexes ligand atoms.
    pl_edge_index = radius(l_pos, p_pos, r=radius_cutoff)
    if pl_edge_index.size(1) == 0:
        _set_empty_ligand_edges(data, p_pos.device)
        return

    p_idx, l_idx = pl_edge_index[0], pl_edge_index[1]
    lp_dist = torch.norm(l_pos[l_idx] - p_pos[p_idx], dim=1, p=2).unsqueeze(-1)

    data['ligand', 'binds', 'protein'].edge_index = torch.stack([l_idx, p_idx], dim=0)
    data['ligand', 'binds', 'protein'].edge_attr = lp_dist
    # Reverse direction carries the same distances.
    data['protein', 'binds', 'ligand'].edge_index = torch.stack([p_idx, l_idx], dim=0)
    data['protein', 'binds', 'ligand'].edge_attr = lp_dist.clone()


def dict_to_pyg_data(feature_dict, radius_cutoff=8.0):
    """Convert a LigandMPNN feature dict into the HeteroData graph the model consumes.

    Node types:
        * ``protein``: one node per residue with ``mask`` set. ``x`` holds the 6
          sin/cos torsion features, ``pos`` the CA coordinates (backbone atom
          index 1), ``y`` the ``S`` labels, plus ``chain_M`` when present.
        * ``ligand``: one node per atom with ``ligand_Y_m`` set. ``x`` holds the
          6-column element one-hot, ``pos`` the coordinates. It is empty when
          there are no ligand atoms.

    Edge types (``edge_attr`` holds the Euclidean distance, shape (E, 1)):
        * (protein, interacts_with, protein): CA pairs within ``radius_cutoff``, no self-loops.
        * (ligand, binds, protein) and its reverse (protein, binds, ligand):
          ligand-atom / CA pairs within ``radius_cutoff``. These may be empty.

    Args:
        feature_dict: Output of ``get_ligandmpnn_features`` (tensors with a
            leading batch dimension of 1 on ``X``, ``S``, ``mask``, ``chain_M``).
        radius_cutoff: Distance cutoff for all radius edges, in the coordinate
            units of ``X`` (presumably Angstrom).

    Returns:
        The populated ``HeteroData``.

    Raises:
        ValueError: If ``X`` is not shaped (L, n_atoms >= 3, 3) after squeezing.
    """
    data = HeteroData()
    p_pos = _add_protein_graph(data, feature_dict, radius_cutoff)
    _add_ligand_graph(data, feature_dict, p_pos, radius_cutoff)
    return data

def pdb_to_pyg_data(pdb_path, radius=8.0, device="cpu"):
    """Parse ``pdb_path`` and build the HeteroData graph for the model.

    ``radius`` is the distance cutoff forwarded to ``dict_to_pyg_data`` as
    ``radius_cutoff``. Inside this function only, the parameter shadows the
    ``radius`` function imported from ``torch_geometric.nn``.
    """
    return dict_to_pyg_data(
        get_ligandmpnn_features(pdb_path, device=device),
        radius_cutoff=radius,
    )


# ---------------------------------------------------------
# 3. INFERENCE ENDPOINT
# ---------------------------------------------------------
def predict_sequence(pdb_file):
    """Gradio callback: predict an amino-acid sequence for an uploaded PDB backbone.

    Args:
        pdb_file: Value passed by the ``gr.File`` input. Newer Gradio versions
            pass a filepath (``str``, or a ``str`` subclass); older ones pass a
            tempfile-like object whose ``.name`` is the full path. An
            ``os.PathLike`` is also accepted. ``None`` means nothing was uploaded.

    Returns:
        A display string: the predicted sequence, an upload prompt, or an
        error message. Errors are reported in the UI rather than raised, and
        their traceback is printed to stderr for the server log.
    """
    if pdb_file is None:
        return "Please upload a .pdb file."

    try:
        # Resolve the upload to a filesystem path. Note: Path.name is only the
        # basename, so path-like objects must go through os.fspath, not .name.
        if isinstance(pdb_file, (str, os.PathLike)):
            pdb_path = os.fspath(pdb_file)
        else:
            pdb_path = pdb_file.name

        # Build the heterogeneous graph and move every store to the model's device.
        data = pdb_to_pyg_data(pdb_path, device=device)
        data = data.to(device)

        num_residues = data['protein'].x.shape[0]

        with torch.no_grad():
            logits = model(data)

        # Greedy decode: highest-scoring class per residue -> one-letter code.
        predicted_indices = torch.argmax(logits, dim=-1).cpu().tolist()
        predicted_seq = "".join(AA_ALPHABET[idx] for idx in predicted_indices)

        return f"Predicted Sequence ({num_residues} residues):\n\n{predicted_seq}"

    except Exception as e:
        # UI boundary: show the message to the user, keep the traceback for debugging.
        traceback.print_exc()
        return f"Error processing PDB: {str(e)}"

# ---------------------------------------------------------
# 4. GRADIO UI
# ---------------------------------------------------------
# Single-input / single-output Gradio app: a .pdb upload in, the designed
# sequence (or an error message from predict_sequence) out.
demo = gr.Interface(
    title="Struct2Seq-GNN: Inverse Protein Folding",
    description=(
        "Upload a 3D target backbone to generate a sequence optimized by a custom Heterogeneous Graph Neural Network.\n\n"
        "**Model Performance:** Achieves ~30.3% global sequence recovery and **35.1% binding-pocket recovery** "
        "on noisy coordinates, confirming strong generalization to underlying biophysical folding constraints."
    ),
    fn=predict_sequence,
    inputs=gr.File(file_types=[".pdb"], label="Upload Target Protein Backbone (.pdb)"),
    outputs=gr.Textbox(lines=5, label="Designed Amino Acid Sequence"),
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()