Spaces:
Build error
Build error
| """ | |
| CFConv-based Message Passing | |
| Complete implementation of continuous-filter message passing | |
| with Gaussian RBF basis and learnable distance filters. | |
| Based on SchNet and DimeNet architectures. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from typing import Optional, Tuple | |
| from torch_geometric.nn import MessagePassing | |
| from .cfconv import GaussianSmearing, ShiftedSoftplus | |
class CFConvMessagePassing(MessagePassing):
    """
    Continuous-filter message passing layer (SchNet-style cfconv).

    Expands interatomic distances in a Gaussian RBF basis, maps the
    expansion through a small MLP to produce a per-edge filter, and
    modulates neighbor features element-wise with that filter before
    sum-aggregation.

    Fix over the previous revision: ``cutoff`` was stored but never
    applied, so filters did not decay to zero at the cutoff radius and
    messages changed discontinuously when a neighbor crossed it. The
    standard SchNet cosine cutoff envelope is now applied in ``message``.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_gaussians: int = 50,
        cutoff: float = 5.0,
    ):
        """
        Args:
            hidden_dim: Width of node features and generated filters.
            num_gaussians: Number of RBF centers spanning [0, cutoff].
            cutoff: Interaction cutoff radius (same units as distances).
        """
        super().__init__(aggr='add')
        self.hidden_dim = hidden_dim
        self.cutoff = cutoff

        # Gaussian RBF basis for distance expansion
        self.distance_expansion = GaussianSmearing(
            start=0.0,
            stop=cutoff,
            num_gaussians=num_gaussians,
        )

        # Filter-generating network (distance -> filter weights)
        self.filter_network = nn.Sequential(
            nn.Linear(num_gaussians, hidden_dim),
            ShiftedSoftplus(),
            nn.Linear(hidden_dim, hidden_dim),
            ShiftedSoftplus(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Pre-/post-message node transforms (SchNet interaction style)
        self.node_mlp_1 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            ShiftedSoftplus(),
        )
        self.node_mlp_2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            ShiftedSoftplus(),
        )

        # Layer normalization for training stability
        self.norm = nn.LayerNorm(hidden_dim)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init (gain 0.5) on all linear layers; zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.5)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_distances: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass with continuous-filter message passing.

        Args:
            x: Node features [num_nodes, hidden_dim].
            edge_index: Edge indices [2, num_edges].
            edge_attr: Edge features (accepted for interface parity;
                not used by cfconv itself).
            edge_distances: Edge distances [num_edges].

        Returns:
            Updated node features [num_nodes, hidden_dim].
        """
        # Store input for residual
        x_input = x

        # Pre-process node features
        x = self.node_mlp_1(x)

        # Message passing with continuous filters
        x_out = self.propagate(
            edge_index,
            x=x,
            edge_distances=edge_distances,
        )

        # Post-process
        x_out = self.node_mlp_2(x_out)

        # Residual connection + normalization
        x_out = self.norm(x_input + x_out)
        return x_out

    def message(
        self,
        x_j: torch.Tensor,
        edge_distances: torch.Tensor,
    ) -> torch.Tensor:
        """
        Create messages from neighbors using continuous filters.

        Args:
            x_j: Neighbor features [num_edges, hidden_dim].
            edge_distances: Edge distances [num_edges].

        Returns:
            Messages [num_edges, hidden_dim].
        """
        # Expand distances with Gaussian RBF -> [num_edges, num_gaussians]
        edge_features = self.distance_expansion(edge_distances)

        # Generate continuous filters from distances -> [num_edges, hidden_dim]
        filters = self.filter_network(edge_features)

        # Smooth cosine cutoff envelope: 1 at d = 0, 0 at d >= cutoff.
        # Without it, messages jump discontinuously as neighbors enter or
        # leave the cutoff sphere (cutoff was previously stored but unused).
        envelope = 0.5 * (torch.cos(edge_distances * torch.pi / self.cutoff) + 1.0)
        envelope = envelope * (edge_distances < self.cutoff).to(envelope.dtype)

        # Apply filters and envelope to neighbor features (element-wise)
        return x_j * filters * envelope.unsqueeze(-1)
class InteractionBlock(nn.Module):
    """
    SchNet-style interaction block.

    One CFConv message-passing step followed by a two-layer dense
    refinement, merged back into the input through a residual connection.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_gaussians: int = 50,
        cutoff: float = 5.0,
    ):
        """
        Args:
            hidden_dim: Node feature width.
            num_gaussians: RBF basis size forwarded to the CFConv layer.
            cutoff: Cutoff radius forwarded to the CFConv layer.
        """
        super().__init__()
        self.cfconv = CFConvMessagePassing(
            hidden_dim=hidden_dim,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
        )
        # Dense refinement applied to the message-passing output.
        self.dense_1 = nn.Linear(hidden_dim, hidden_dim)
        self.dense_2 = nn.Linear(hidden_dim, hidden_dim)
        self.activation = ShiftedSoftplus()

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_distances: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run one interaction step and return residually updated features.

        Args:
            x: Node features [num_nodes, hidden_dim].
            edge_index: Edge indices [2, num_edges].
            edge_attr: Edge features (passed through; unused by CFConv).
            edge_distances: Edge distances [num_edges].

        Returns:
            Updated node features [num_nodes, hidden_dim].
        """
        # Message passing, then dense refinement of the update.
        delta = self.cfconv(x, edge_index, edge_attr, edge_distances)
        delta = self.activation(self.dense_1(delta))
        delta = self.dense_2(delta)
        # Residual merge.
        return x + delta