WSobo committed on
Commit
dbc4151
·
verified ·
1 Parent(s): f678ca5

Create model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +128 -0
model_utils.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch_geometric.nn import TransformerConv, Linear, HeteroConv
4
+ import torch.nn.functional as F
5
+
6
class GaussianSmearing(nn.Module):
    """Expand scalar distances into a Gaussian radial-basis (RBF) feature vector.

    ``num_gaussians`` centres are placed evenly on ``[start, stop]``. For a
    distance ``d`` each output channel is ``exp(coeff * (d - centre) ** 2)``
    with ``coeff = -0.5 / spacing ** 2``, where ``spacing`` is the gap between
    neighbouring centres.

    Args:
        start: Position of the first Gaussian centre.
        stop: Position of the last Gaussian centre.
        num_gaussians: Number of centres, i.e. the width of the output
            feature vector (16 by default).

    Raises:
        ValueError: If ``num_gaussians`` is smaller than 2 or ``stop`` equals
            ``start``; in both cases the centre spacing would be zero and the
            Gaussian coefficient undefined.
    """

    def __init__(self, start=0.0, stop=8.0, num_gaussians=16):
        super().__init__()
        # Fail early with a readable message instead of a bare
        # ZeroDivisionError from the coefficient computation below.
        if num_gaussians < 2:
            raise ValueError(
                f"num_gaussians must be at least 2, got {num_gaussians}"
            )
        if stop == start:
            raise ValueError(
                f"stop must differ from start, got start={start}, stop={stop}"
            )

        offset = torch.linspace(start, stop, num_gaussians)
        # Width of every Gaussian equals the spacing between adjacent centres.
        spacing = (stop - start) / (num_gaussians - 1)
        self.coeff = -0.5 / spacing ** 2
        # Buffer (not a parameter): follows .to(device) and is saved in the
        # state dict, but is never trained.
        self.register_buffer('offset', offset)

    def forward(self, dist):
        """Return the RBF expansion of ``dist``.

        Args:
            dist: Tensor of distances; it is flattened, so any shape is
                accepted.

        Returns:
            Tensor of shape ``[dist.numel(), num_gaussians]``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        diff = dist.reshape(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(diff, 2))
23
+
24
class ResidualAttentionBlock(nn.Module):
    """TransformerConv followed by LayerNorm, ReLU, dropout and a residual add.

    The convolution uses ``heads`` attention heads of ``hidden_dim // heads``
    channels each with ``concat=True``, so its output width equals
    ``hidden_dim`` and can be added to the destination-node input.

    Args:
        hidden_dim: Width of the node features entering and leaving the block.
        edge_dim: Width of the edge features handed to ``TransformerConv``.
        dropout: Dropout probability applied after the ReLU.
        heads: Number of attention heads; must divide ``hidden_dim`` evenly.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``hidden_dim``. Without this check the concatenated head output
            would be narrower than ``hidden_dim`` and the LayerNorm / residual
            add would only fail later, inside ``forward``.
    """

    def __init__(self, hidden_dim, edge_dim=19, dropout=0.1, heads=4):
        super().__init__()
        if heads < 1 or hidden_dim % heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"number of heads ({heads})"
            )
        # heads * (hidden_dim // heads) == hidden_dim, so tensor widths are
        # preserved through the block.
        self.conv = TransformerConv(
            hidden_dim,
            hidden_dim // heads,
            heads=heads,
            concat=True,
            edge_dim=edge_dim,
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_attr):
        """Run one attention step and add the skip connection.

        Args:
            x: Node features, or a ``(source, destination)`` tuple for
                bipartite relations such as ``('ligand', 'binds', 'protein')``.
            edge_index: Edge indices forwarded unchanged to the convolution.
            edge_attr: Edge features forwarded unchanged to the convolution.

        Returns:
            Updated destination-node features: the normalised, activated conv
            output plus the destination-node input.
        """
        # For bipartite input the residual must come from the destination
        # nodes, because those are the rows the convolution produces.
        identity = x[1] if isinstance(x, tuple) else x

        out = self.conv(x, edge_index, edge_attr)
        out = self.norm(out)
        out = F.relu(out)
        out = self.dropout(out)
        return out + identity
46
+
47
class Struct2SeqGNN(nn.Module):
    """Heterogeneous attention GNN producing per-protein-node class logits.

    Protein and ligand node features are embedded to ``hidden_dim``. Edge
    features are rebuilt on every forward pass from node positions, then
    ``num_layers`` ``HeteroConv`` layers of ``ResidualAttentionBlock`` are
    applied, followed by LayerNorm and a linear classifier on the protein
    nodes only.

    Args:
        node_features: Width of the raw protein node features.
        ligand_features: Width of the raw ligand node features.
        hidden_dim: Width of the hidden node representation.
        num_classes: Number of output classes per protein node.
        num_layers: Number of stacked ``HeteroConv`` layers.
        dropout: Dropout probability used inside every attention block.
        num_gaussians: Number of RBF channels used to encode edge length.
            The edge feature width is ``num_gaussians + 3`` (19 by default);
            the ``+ 3`` is the unit direction vector and assumes 3-D ``pos``.
    """

    # Relations handled by the model, in construction order. The order is
    # kept fixed so parameter initialisation under a given seed is unchanged.
    _EDGE_TYPES = (
        ('protein', 'interacts_with', 'protein'),
        ('ligand', 'binds', 'protein'),
        ('protein', 'binds', 'ligand'),
    )

    def __init__(self, node_features=6, ligand_features=6, hidden_dim=128, num_classes=21, num_layers=4, dropout=0.1, num_gaussians=16):
        super().__init__()

        # RBF channels for the length plus 3 components of the unit direction.
        edge_dim = num_gaussians + 3

        # Initial node embeddings for distinct node types.
        self.protein_emb = Linear(node_features, hidden_dim)
        self.ligand_emb = Linear(ligand_features, hidden_dim)

        # One distance expansion per edge type, so relations stay decoupled.
        self.edge_embs = nn.ModuleDict({
            '__'.join(edge_type): GaussianSmearing(
                start=0.0, stop=8.0, num_gaussians=num_gaussians
            )
            for edge_type in self._EDGE_TYPES
        })

        # Stack of heterogeneous layers, one attention block per relation.
        # NOTE(review): every block adds its own residual and the relations
        # are merged with aggr='sum', so a protein node that is the target of
        # two relations receives its identity twice - confirm this is intended.
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                edge_type: ResidualAttentionBlock(
                    hidden_dim, edge_dim=edge_dim, dropout=dropout
                )
                for edge_type in self._EDGE_TYPES
            }, aggr='sum')
            self.layers.append(conv)

        # Final normalisation before classification (protein output only).
        self.norm_out = nn.LayerNorm(hidden_dim)

        # Per-node classification head.
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, data):
        """Compute class logits for every protein node.

        Args:
            data: Heterogeneous graph object exposing ``x_dict``,
                ``edge_index_dict``, ``edge_attr_dict`` and per-node-type
                ``pos`` via ``data[node_type].pos`` (presumably a
                torch_geometric ``HeteroData`` - confirm against the caller).

        Returns:
            Tensor of logits with one row per protein node and
            ``num_classes`` columns.
        """
        x_dict = data.x_dict
        edge_index_dict = data.edge_index_dict
        edge_attr_dict = data.edge_attr_dict

        # 1. Embed node features; ligands are optional.
        h_dict = {'protein': self.protein_emb(x_dict['protein'])}
        if 'ligand' in x_dict:
            h_dict['ligand'] = self.ligand_emb(x_dict['ligand'])

        # 2. Build edge features. Iterate over edge_index_dict (not
        # edge_attr_dict): distances are recomputed from pos, so a relation
        # needs no stored edge_attr, and skipping it would leave the attention
        # block without its required edge_attr argument.
        edge_feats = {}
        for edge_type, edge_index in edge_index_dict.items():
            src_type, rel_type, dst_type = edge_type

            # Relations touching a node type that was not embedded are never
            # processed by the layers, so do not require pos for them.
            if src_type not in h_dict or dst_type not in h_dict:
                continue

            key = f"{src_type}__{rel_type}__{dst_type}"
            if key not in self.edge_embs:
                # Unknown relation: pass stored features through, if any.
                if edge_type in edge_attr_dict:
                    edge_feats[edge_type] = edge_attr_dict[edge_type]
                continue

            # Recompute geometry from pos so length and direction stay
            # consistent even when noise is injected into pos during training.
            src_idx, dst_idx = edge_index
            vec = data[dst_type].pos[dst_idx] - data[src_type].pos[src_idx]
            dist_new = torch.norm(vec, dim=-1)

            # Unit direction vectors; the epsilon guards zero-length edges.
            vec_norm = vec / (dist_new.unsqueeze(-1) + 1e-7)

            # Length as [E, num_gaussians] plus direction as [E, 3].
            dist_smeared = self.edge_embs[key](dist_new)
            edge_feats[edge_type] = torch.cat([dist_smeared, vec_norm], dim=-1)

        # 3. Iterative message passing through the HeteroConv stack.
        for layer in self.layers:
            h_dict = layer(h_dict, edge_index_dict, edge_attr_dict=edge_feats)

        # 4. Readout on protein nodes.
        protein_x = self.norm_out(h_dict['protein'])
        return self.fc(protein_x)