File size: 5,871 Bytes
dbc4151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import torch.nn as nn
from torch_geometric.nn import TransformerConv, Linear, HeteroConv
import torch.nn.functional as F

class GaussianSmearing(nn.Module):
    """
    Expands scalar distances into a Radial Basis Function (RBF) feature vector.

    Every distance is compared against ``num_gaussians`` evenly spaced centers
    between ``start`` and ``stop``. The Gaussian width equals the spacing between
    neighbouring centers, so adjacent basis functions overlap.

    Args:
        start: Position of the first Gaussian center.
        stop: Position of the last Gaussian center.
        num_gaussians: Number of centers, i.e. the size of the output feature axis
            (16 by default).

    Raises:
        ValueError: If the center spacing is zero or undefined
            (``num_gaussians == 1`` or ``start == stop``), because no Gaussian
            width can be derived in that case.
    """
    def __init__(self, start=0.0, stop=8.0, num_gaussians=16):
        super().__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        try:
            # Width of each Gaussian == distance between two adjacent centers.
            spacing = (stop - start) / (num_gaussians - 1)
            self.coeff = -0.5 / spacing ** 2
        except ZeroDivisionError as err:
            raise ValueError(
                "GaussianSmearing requires num_gaussians != 1 and start != stop "
                f"(got start={start}, stop={stop}, num_gaussians={num_gaussians})"
            ) from err
        # Registered as a buffer so the centers follow the module across devices
        # and are saved in the state dict without being trained.
        self.register_buffer('offset', offset)

    def forward(self, dist):
        """
        Smear distances over the Gaussian centers.

        Args:
            dist: Tensor of distances; it is flattened, so any shape holding E
                values in total is accepted.

        Returns:
            Tensor of shape [E, num_gaussians] with values in (0, 1].
        """
        # reshape (rather than view) so non-contiguous inputs are accepted too.
        delta = dist.reshape(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * delta.pow(2))

class ResidualAttentionBlock(nn.Module):
    """
    TransformerConv attention followed by LayerNorm, ReLU, dropout and a skip connection.

    Args:
        hidden_dim: Width of the node embeddings entering and leaving the block.
        edge_dim: Width of the edge features passed to TransformerConv.
        dropout: Dropout probability applied after the activation.
        heads: Number of attention heads. ``hidden_dim`` is split evenly across
            them so the concatenated head outputs are ``hidden_dim`` wide again.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``hidden_dim``.
            Without this check the concatenated conv output would be narrower
            than ``hidden_dim`` and the LayerNorm / residual add would only fail
            later, inside ``forward``.
    """
    def __init__(self, hidden_dim, edge_dim=19, dropout=0.1, heads=4):
        super().__init__()
        if heads < 1 or hidden_dim % heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive number of heads ({heads})"
            )
        # TransformerConv takes edge_attr into its attention computation (edge_dim).
        # With concat=True the per-head outputs are concatenated, giving
        # heads * (hidden_dim // heads) == hidden_dim output channels.
        self.conv = TransformerConv(hidden_dim, hidden_dim // heads, heads=heads, concat=True, edge_dim=edge_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_attr):
        """
        Run one attention step and add the input back as a residual.

        Args:
            x: Node features, or a ``(x_src, x_dst)`` tuple for bipartite edge
                types (e.g. ['ligand', 'binds', 'protein']).
            edge_index: Edge indices forwarded unchanged to TransformerConv.
            edge_attr: Edge features forwarded unchanged to TransformerConv.

        Returns:
            The normalized, activated conv output plus the residual input.
        """
        # For bipartite input the skip connection uses the destination features.
        identity = x[1] if isinstance(x, tuple) else x

        out = self.conv(x, edge_index, edge_attr)
        out = self.norm(out)
        out = F.relu(out)
        out = self.dropout(out)
        return out + identity

class Struct2SeqGNN(nn.Module):
    """
    Heterogeneous attention GNN producing per-protein-node class logits.

    Protein and ligand node features are embedded to ``hidden_dim``. Edge features
    are rebuilt from node positions as RBF-smeared distances plus a unit direction
    vector. ``num_layers`` HeteroConv layers then pass messages along the three
    supported edge types before a LayerNorm + linear readout on the protein nodes.

    Args:
        node_features: Width of the raw protein node features.
        ligand_features: Width of the raw ligand node features.
        hidden_dim: Width of all node embeddings inside the network.
        num_classes: Number of output classes per protein node.
        num_layers: Number of HeteroConv message-passing layers.
        dropout: Dropout probability used in every ResidualAttentionBlock.
        num_gaussians: Number of RBF centers used for the distance expansion.
        max_distance: Position of the last RBF center (the first one is at 0.0).
    """
    def __init__(self, node_features=6, ligand_features=6, hidden_dim=128, num_classes=21, num_layers=4,
                 dropout=0.1, num_gaussians=16, max_distance=8.0):
        super().__init__()

        # Edge types handled by the network, as (source, relation, destination).
        edge_types = (
            ('protein', 'interacts_with', 'protein'),
            ('ligand', 'binds', 'protein'),
            ('protein', 'binds', 'ligand'),
        )
        # Edge features = smeared distance (num_gaussians) + 3D unit direction vector.
        edge_dim = num_gaussians + 3

        # Initial node embeddings for distinct node types
        self.protein_emb = Linear(node_features, hidden_dim)
        self.ligand_emb = Linear(ligand_features, hidden_dim)

        # Distinct RBF distance expansions for each edge type, keyed "src__rel__dst".
        # Decoupling these allows the network to scale bonds differently
        # (e.g. covalent backbone vs weak ionic ligand bonds).
        self.edge_embs = nn.ModuleDict({
            '__'.join(edge_type): GaussianSmearing(start=0.0, stop=max_distance, num_gaussians=num_gaussians)
            for edge_type in edge_types
        })

        # Stack of HeteroConv layers, one ResidualAttentionBlock per edge type;
        # messages arriving at the same node type are summed.
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                edge_type: ResidualAttentionBlock(hidden_dim, edge_dim=edge_dim, dropout=dropout)
                for edge_type in edge_types
            }, aggr='sum')
            self.layers.append(conv)

        # Final layer normalization before classification (only applied to protein output)
        self.norm_out = nn.LayerNorm(hidden_dim)

        # Sequence prediction: linear classification layer for standard amino acids
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, data):
        """
        Compute class logits for the protein nodes of a heterogeneous graph.

        Args:
            data: Graph object exposing ``x_dict``, ``edge_index_dict``,
                ``edge_attr_dict`` and ``data[node_type].pos`` (presumably a
                torch_geometric ``HeteroData`` - confirm against the caller).
                ``x_dict`` must contain 'protein'; 'ligand' is optional.

        Returns:
            Tensor with ``num_classes`` logits per row of ``x_dict['protein']``.
        """
        x_dict = data.x_dict
        edge_index_dict = data.edge_index_dict
        edge_attr_dict = data.edge_attr_dict

        # 1. Embed node features
        h_dict = {'protein': self.protein_emb(x_dict['protein'])}
        if 'ligand' in x_dict:
            h_dict['ligand'] = self.ligand_emb(x_dict['ligand'])

        # Build edge features for every edge type that has connectivity. Iterating
        # edge_index_dict (not edge_attr_dict) matters: the features are recomputed
        # from pos, so an edge type without a stored edge_attr must still get them.
        edge_feats = {}
        for edge_type, edge_index in edge_index_dict.items():
            src_type, rel_type, dst_type = edge_type
            key = f"{src_type}__{rel_type}__{dst_type}"

            if key not in self.edge_embs:
                # No RBF expansion for this edge type: pass stored attributes
                # through untouched (if any) and skip the position lookups.
                if edge_type in edge_attr_dict:
                    edge_feats[edge_type] = edge_attr_dict[edge_type]
                continue

            # Generate 3D direction vectors on-the-fly. The distance is recomputed
            # from pos so that it matches the vectors even when Gaussian noise is
            # injected into pos during training.
            src_idx, dst_idx = edge_index
            vec = data[dst_type].pos[dst_idx] - data[src_type].pos[src_idx]
            dist = torch.norm(vec, dim=-1)

            # Purely directional unit vectors; the epsilon guards zero-length edges.
            unit_vec = vec / (dist.unsqueeze(-1) + 1e-7)

            # Combine length and direction -> [E, num_gaussians + 3]
            edge_feats[edge_type] = torch.cat([self.edge_embs[key](dist), unit_vec], dim=-1)

        # 2. Iterative message passing through HeteroConv
        for layer in self.layers:
            updated = layer(h_dict, edge_index_dict, edge_attr_dict=edge_feats)
            # Keep the previous embedding for node types that received no messages
            # in this layer, so later layers still find them as message sources.
            h_dict = {**h_dict, **updated}

        # 3. Readout on protein nodes
        protein_x = self.norm_out(h_dict['protein'])
        return self.fc(protein_x)