atlas-1-demo / src /layers /cfconv.py
Reverb's picture
Upload folder using huggingface_hub
8eabce6 verified
"""
Continuous-Filter Convolution (CFConv)
SchNet-style continuous-filter convolutions for molecular graphs.
Uses learnable distance filters with Gaussian RBF basis.
"""
import torch
import torch.nn as nn
from typing import Tuple, Optional
import math
class CFConv(nn.Module):
    """
    Continuous-filter convolution layer (SchNet-style).

    Each edge distance is expanded in a Gaussian RBF basis, and a small MLP
    maps that expansion to a per-edge ``in_features x out_features`` weight
    matrix. Each neighbor's features are multiplied by their edge's matrix,
    the resulting messages are summed into the receiving node, and a dense
    transform of the node's own features is added.

    Args:
        in_features: Size of the input node features.
        out_features: Size of the output node features.
        num_gaussians: Number of RBF centers used to expand each distance.
        cutoff: Upper end of the RBF center range. NOTE: it only places the
            Gaussian centers on ``[0, cutoff]``; no cutoff envelope is
            applied, so edges longer than ``cutoff`` still produce filters.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_gaussians: int = 50,
        cutoff: float = 5.0,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.cutoff = cutoff

        # Gaussian RBF basis for distance expansion (centers on [0, cutoff]).
        self.distance_expansion = GaussianSmearing(
            start=0.0,
            stop=cutoff,
            num_gaussians=num_gaussians,
        )

        # Filter-generating network: RBF features -> flattened per-edge
        # (in_features x out_features) weight matrix.
        self.filter_network = nn.Sequential(
            nn.Linear(num_gaussians, in_features),
            nn.SiLU(),
            nn.Linear(in_features, in_features * out_features),
        )

        # Dense transform applied to each node's own features.
        self.dense = nn.Linear(in_features, out_features)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform (gain 0.5) weights and zero biases for every Linear."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.5)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_distances: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass with continuous filters.

        Args:
            x: Node features [num_nodes, in_features]
            edge_index: Edge indices [2, num_edges]; features are gathered
                from ``edge_index[1]`` and summed into ``edge_index[0]``.
            edge_distances: Edge distances [num_edges] (or [num_edges, 1])

        Returns:
            output: Transformed features [num_nodes, out_features], with the
            same dtype as the computed messages.
        """
        row, col = edge_index

        # Expand distances with Gaussian RBF, then generate one
        # (in_features x out_features) filter matrix per edge.
        edge_features = self.distance_expansion(edge_distances)
        filters = self.filter_network(edge_features)
        filters = filters.view(-1, self.in_features, self.out_features)

        # Continuous convolution: per-edge matrix-vector product.
        neighbor_features = x[col]
        messages = torch.einsum('ei,eio->eo', neighbor_features, filters)

        # Allocate the buffer with the messages' dtype/device. A plain
        # torch.zeros would default to float32, and index_add_ requires
        # matching dtypes, so float64 / float16 / autocast runs would fail.
        output = messages.new_zeros((x.size(0), self.out_features))
        output.index_add_(0, row, messages)

        # Add dense transformation of each node's own features.
        return output + self.dense(x)
class GaussianSmearing(nn.Module):
    """
    Gaussian distance expansion (SchNet-style).

    Expands each scalar distance into ``num_gaussians`` Gaussian basis
    functions. The centers are evenly spaced on ``[start, stop]``, and the
    width (standard deviation) equals the spacing between adjacent centers.

    Args:
        start: First Gaussian center.
        stop: Last Gaussian center.
        num_gaussians: Number of centers; must be at least 2 so a spacing
            (and hence a width) is defined.

    Raises:
        ValueError: If ``num_gaussians < 2`` or the centers have zero spacing.
    """

    def __init__(
        self,
        start: float = 0.0,
        stop: float = 5.0,
        num_gaussians: int = 50,
    ):
        super().__init__()
        if num_gaussians < 2:
            raise ValueError(
                f"num_gaussians must be at least 2, got {num_gaussians}"
            )

        # Gaussian centers (evenly spaced)
        offset = torch.linspace(start, stop, num_gaussians)
        self.register_buffer('offset', offset)

        spacing = (offset[1] - offset[0]).item()
        if spacing == 0.0:
            raise ValueError(
                f"start and stop must differ to define a Gaussian width, "
                f"got start={start}, stop={stop}"
            )

        # coeff = -1 / (2 * sigma^2) with sigma = center spacing. It is
        # negative, so exp(coeff * (d - mu)^2) decays away from each center.
        coeff = -0.5 / spacing ** 2
        self.register_buffer('coeff', torch.tensor(coeff))

    def forward(self, distances: torch.Tensor) -> torch.Tensor:
        """
        Expand distances into Gaussian basis.

        Args:
            distances: [num_edges] or [num_edges, 1]

        Returns:
            expanded: [num_edges, num_gaussians]
        """
        if distances.dim() == 1:
            distances = distances.unsqueeze(-1)
        # Gaussian RBF: exp(coeff * (d - offset)^2), coeff < 0; broadcasting
        # [num_edges, 1] against [num_gaussians] gives [num_edges, num_gaussians].
        return torch.exp(self.coeff * (distances - self.offset) ** 2)
class ShiftedSoftplus(nn.Module):
    """
    Shifted softplus activation (SchNet-style): ``softplus(x) - log(2)``.

    Smooth everywhere, and the shift makes the output vanish at the origin.
    """

    def __init__(self):
        super().__init__()
        # Computed in double precision. A float32-derived constant would leave
        # ssp(0) at about -1.9e-9 for float64 inputs. For float32 inputs the
        # Python float is rounded to float32 during the subtraction, so the
        # float32 results are the same.
        self.shift: float = math.log(2.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.softplus(x) - self.shift