# NOTE(review): removed extraction artifacts ("Spaces:" / "Build error" lines)
# that were not part of the module and broke Python syntax.
| """ | |
| Continuous-Filter Convolution (CFConv) | |
| SchNet-style continuous-filter convolutions for molecular graphs. | |
| Uses learnable distance filters with Gaussian RBF basis. | |
| """ | |
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
class CFConv(nn.Module):
    """
    Continuous-filter convolution layer (SchNet-style).

    Expands edge distances in a Gaussian RBF basis, generates a per-edge
    (in_features x out_features) filter matrix from that expansion, applies
    the filter to the neighbor's features, and sum-aggregates the resulting
    messages per node. A dense transform of the node features is added on
    top (residual-style).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_gaussians: int = 50,
        cutoff: float = 5.0,
    ):
        """
        Args:
            in_features: Dimension of input node features.
            out_features: Dimension of output node features.
            num_gaussians: Number of RBF centers used for distance expansion.
            cutoff: Upper end of the RBF grid, in distance units.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # NOTE(review): `cutoff` only sets the end of the RBF grid below; no
        # smooth cutoff envelope is applied to the filters (reference SchNet
        # multiplies filters by a cosine cutoff) — confirm this is intended.
        self.cutoff = cutoff
        # Gaussian RBF basis for distance expansion.
        self.distance_expansion = GaussianSmearing(
            start=0.0,
            stop=cutoff,
            num_gaussians=num_gaussians,
        )
        # Filter-generating network: RBF features -> flattened
        # (in_features * out_features) filter per edge.
        self.filter_network = nn.Sequential(
            nn.Linear(num_gaussians, in_features),
            nn.SiLU(),
            nn.Linear(in_features, in_features * out_features),
        )
        # Dense layer applied directly to node features (added to the
        # aggregated messages in forward()).
        self.dense = nn.Linear(in_features, out_features)
        # Initialize all linear layers.
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init (gain 0.5) for every linear layer; zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.5)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_distances: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass with continuous filters.

        Args:
            x: Node features [num_nodes, in_features].
            edge_index: Edge indices [2, num_edges]; messages flow from
                edge_index[1] (source/neighbor) into edge_index[0] (target).
            edge_distances: Edge distances [num_edges].

        Returns:
            Transformed node features [num_nodes, out_features].
        """
        row, col = edge_index
        # Expand distances with the Gaussian RBF basis.
        edge_features = self.distance_expansion(edge_distances)
        # Generate one filter matrix per edge from the expanded distances.
        filters = self.filter_network(edge_features)
        filters = filters.view(-1, self.in_features, self.out_features)
        # Apply filters to the neighbor (source) features.
        neighbor_features = x[col]
        # Continuous convolution: per-edge matrix-vector product.
        messages = torch.einsum('ei,eio->eo', neighbor_features, filters)
        # Aggregate messages into target nodes.
        # FIX: match x's dtype — without dtype=x.dtype, index_add_ raises a
        # dtype mismatch whenever x is not the default float32 (e.g. float64).
        output = torch.zeros(
            x.size(0), self.out_features, device=x.device, dtype=x.dtype
        )
        output.index_add_(0, row, messages)
        # Add the dense transformation of the original node features.
        output = output + self.dense(x)
        return output
class GaussianSmearing(nn.Module):
    """
    Gaussian distance expansion (SchNet-style).

    Expands scalar distances into a fixed set of Gaussian basis functions
    with evenly spaced centers and a width tied to the center spacing.
    """

    def __init__(
        self,
        start: float = 0.0,
        stop: float = 5.0,
        num_gaussians: int = 50,
    ):
        """
        Args:
            start: Center of the first Gaussian.
            stop: Center of the last Gaussian; must be greater than `start`.
            num_gaussians: Number of basis functions; must be at least 2
                (the width is derived from the spacing between centers).

        Raises:
            ValueError: If `num_gaussians < 2` or `stop <= start`.
        """
        super().__init__()
        # FIX: fail loudly instead of an opaque IndexError (num_gaussians < 2)
        # or ZeroDivisionError (stop == start) below.
        if num_gaussians < 2:
            raise ValueError(
                f"num_gaussians must be >= 2, got {num_gaussians}"
            )
        if stop <= start:
            raise ValueError(f"stop ({stop}) must be greater than start ({start})")
        # Gaussian centers (evenly spaced); buffer so it follows .to(device).
        offset = torch.linspace(start, stop, num_gaussians)
        self.register_buffer('offset', offset)
        # Negative inverse variance: coeff = -1 / (2 * spacing^2).
        coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
        self.register_buffer('coeff', torch.tensor(coeff))

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """
        Expand distances into the Gaussian basis.

        Args:
            distances: [num_edges] or [num_edges, 1].

        Returns:
            expanded: [num_edges, num_gaussians].
        """
        if distances.dim() == 1:
            distances = distances.unsqueeze(-1)
        # Gaussian RBF: exp(coeff * (d - offset)^2); coeff is already negative.
        return torch.exp(self.coeff * (distances - self.offset) ** 2)
class ShiftedSoftplus(nn.Module):
    """
    Shifted softplus activation (SchNet-style).

    Computes softplus(x) - log(2), which is smooth everywhere and passes
    through zero at the origin (softplus(0) == log(2)).
    """

    def __init__(self):
        super().__init__()
        # Precompute the constant shift once; float() of a 0-dim tensor is
        # the same value .item() would give (float32 log(2)).
        self.shift = float(torch.log(torch.tensor(2.0)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = nn.functional.softplus(x)
        return activated - self.shift