import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from torch_geometric.nn import InstanceNorm


class EGNNLayer(nn.Module):
    """
    EGNN layer with an optional feed-forward block and per-graph (instance) normalization.

    Args:
        input_nf: Number of input node features
        output_nf: Number of output node features
        hidden_nf: Number of hidden features
        edges_in_d: Number of input edge features
        nodes_att_dim: Number of extra node attribute features used in the node update
        act_fn: Activation function
        residual: Whether to use a residual connection on the node features
        attention: Whether to gate edge messages with a learned attention weight
        normalize: Whether to normalize coordinate differences by their distance
        coords_agg: Aggregation method for coordinate updates ('mean', 'sum', 'max', 'min')
        tanh: Whether to apply tanh to the coordinate update scale
        dropout: Dropout rate
        ffn: Whether to apply a feed-forward block after the node update
        batch_norm: Whether to apply InstanceNorm (per-graph normalization) to node
            features and coordinates
    """
    def __init__(self, input_nf, output_nf, hidden_nf,
                 edges_in_d=0, nodes_att_dim=0, act_fn=nn.SiLU(),
                 residual=True, attention=False, normalize=False,
                 coords_agg='mean', tanh=False, dropout=0.0,
                 ffn=False, batch_norm=True):
        super().__init__()
        self.input_nf = input_nf
        self.output_nf = output_nf
        self.hidden_nf = hidden_nf
        self.residual = residual
        self.attention = attention
        self.normalize = normalize
        self.coords_agg = coords_agg
        self.tanh = tanh
        self.epsilon = 1e-8
        self.dropout = dropout
        self.ffn = ffn
        self.batch_norm = batch_norm

        # Edge model: operates on [h_i, h_j, squared distance, optional edge features]
        in_edge = input_nf * 2 + 1 + edges_in_d
        self.edge_mlp = nn.Sequential(
            nn.Linear(in_edge, hidden_nf),
            act_fn, nn.Dropout(dropout),
            nn.Linear(hidden_nf, hidden_nf),
            act_fn, nn.Dropout(dropout),
        )
        if attention:
            self.att_mlp = nn.Sequential(nn.Linear(hidden_nf, 1), nn.Sigmoid())

        # Coordinate model: the final layer is initialized with a small gain so that
        # early coordinate updates stay close to zero.
        layer = nn.Linear(hidden_nf, 1, bias=False)
        nn.init.xavier_uniform_(layer.weight, gain=0.001)
        coord_blocks = [nn.Linear(hidden_nf, hidden_nf), act_fn,
                        nn.Dropout(dropout), layer]
        if tanh:
            coord_blocks.append(nn.Tanh())
        self.coord_mlp = nn.Sequential(*coord_blocks)

        # Node model: operates on [h_i, aggregated edge messages, optional node attributes]
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn, nn.Dropout(dropout),
            nn.Linear(hidden_nf, output_nf),
        )

        if batch_norm:
            self.norm_node = InstanceNorm(output_nf, affine=True)
            self.norm_coord = InstanceNorm(3, affine=True)

        if ffn:
            self.ff1 = nn.Linear(output_nf, output_nf * 2)
            self.ff2 = nn.Linear(output_nf * 2, output_nf)
            self.act_ff = act_fn
            self.drop_ff = nn.Dropout(dropout)
            if batch_norm:
                self.norm_ff1 = InstanceNorm(output_nf, affine=True)
                self.norm_ff2 = InstanceNorm(output_nf, affine=True)

    def coord2radial(self, edge_index, coord):
        """Return per-edge squared distances and (optionally normalized) coordinate differences."""
        row, col = edge_index
        diff = coord[row] - coord[col]
        dist2 = (diff ** 2).sum(dim=-1, keepdim=True)

        # Clamp to avoid division by zero and to bound very large squared distances.
        dist2 = torch.clamp(dist2, min=self.epsilon, max=100.0)

        if self.normalize:
            norm = dist2.sqrt().detach() + self.epsilon
            diff = diff / norm

        diff = torch.where(torch.isfinite(diff), diff, torch.zeros_like(diff))
        return dist2, diff

    def _ff_block(self, x):
        """Feed-forward block."""
        x = self.drop_ff(self.act_ff(self.ff1(x)))
        return self.ff2(x)

    def forward(self, h, coord, edge_index, batch, edge_attr=None, node_attr=None):
        """Update node features and coordinates for one message-passing step.

        Args:
            h: Node features, shape [num_nodes, input_nf]
            coord: Node coordinates, shape [num_nodes, 3]
            edge_index: Edge indices, shape [2, num_edges]; messages are aggregated
                onto the nodes indexed by the first row
            batch: Graph assignment per node (used by InstanceNorm)
            edge_attr: Optional edge features, shape [num_edges, edges_in_d]
            node_attr: Optional node attributes, shape [num_nodes, nodes_att_dim]
        """
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)

        # Edge model
        e_in = [h[row], h[col], radial]
        if edge_attr is not None:
            e_in.append(edge_attr)
        e = torch.cat(e_in, dim=-1)
        e = self.edge_mlp(e)
        if self.attention:
            att = self.att_mlp(e)
            e = e * att

        # Coordinate update: scale the (possibly normalized) differences by a learned
        # per-edge factor, clamp for stability, and aggregate onto the receiving nodes.
        coord_update = self.coord_mlp(e)
        coord_update = torch.clamp(coord_update, -1.0, 1.0)
        trans = coord_diff * coord_update
        trans = torch.where(torch.isfinite(trans), trans, torch.zeros_like(trans))

        agg_coord = scatter(trans, row, dim=0,
                            dim_size=coord.size(0),
                            reduce=self.coords_agg)
        coord = coord + agg_coord
        coord = torch.where(torch.isfinite(coord), coord, torch.zeros_like(coord))

        if self.batch_norm:
            coord = self.norm_coord(coord, batch)

        # Node update: sum incoming edge messages and combine them with the current features.
        agg_node = scatter(e, row, dim=0,
                           dim_size=h.size(0), reduce='sum')
        x_in = torch.cat([h, agg_node], dim=-1)
        if node_attr is not None:
            x_in = torch.cat([x_in, node_attr], dim=-1)
        h_new = self.node_mlp(x_in)
        if self.batch_norm:
            h_new = self.norm_node(h_new, batch)
        if self.residual and h_new.shape[-1] == h.shape[-1]:
            h_new = h + h_new

        # Optional feed-forward block with pre- and post-normalization
        if self.ffn:
            if self.batch_norm:
                h_new = self.norm_ff1(h_new, batch)
            h_new = h_new + self._ff_block(h_new)
            if self.batch_norm:
                h_new = self.norm_ff2(h_new, batch)

        return h_new, coord, e


class EGNNLayer2(EGNNLayer):
    """Identical to EGNNLayer (the original body was an exact copy of that class)."""
    pass
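

# The block below is a minimal, illustrative smoke test rather than part of the layer
# itself: it builds a small fully connected graph with arbitrary feature sizes and runs
# one EGNNLayer forward pass to show the expected call signature and output shapes.
if __name__ == "__main__":
    num_nodes, nf = 4, 16
    h = torch.randn(num_nodes, nf)
    coord = torch.randn(num_nodes, 3)
    # Fully connected directed edges (no self-loops); the first row receives the messages.
    row, col = zip(*[(i, j) for i in range(num_nodes) for j in range(num_nodes) if i != j])
    edge_index = torch.tensor([row, col], dtype=torch.long)
    batch = torch.zeros(num_nodes, dtype=torch.long)  # all nodes belong to a single graph

    layer = EGNNLayer(input_nf=nf, output_nf=nf, hidden_nf=32)
    h_out, coord_out, edge_msg = layer(h, coord, edge_index, batch)
    print(h_out.shape, coord_out.shape, edge_msg.shape)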