# atlas-1-demo / src/layers/cfconv_message_passing.py
# Uploaded by Reverb with huggingface_hub ("Upload folder using huggingface_hub"),
# commit 8eabce6 (verified).
"""
CFConv-based Message Passing
Complete implementation of continuous-filter message passing
with Gaussian RBF basis and learnable distance filters.
Based on SchNet and DimeNet architectures.
"""
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing

from .cfconv import GaussianSmearing, ShiftedSoftplus
class CFConvMessagePassing(MessagePassing):
    """
    Continuous-filter message passing layer.

    Each edge distance is expanded in a Gaussian RBF basis and mapped by a
    small MLP to a per-edge filter of width ``hidden_dim``. A message is the
    element-wise product of that filter with the (pre-processed) neighbour
    features. Messages are summed over incoming edges, passed through a
    second MLP, added back onto the layer input, and layer-normalised.

    Args:
        hidden_dim: Width of the node features and of the generated filters.
        num_gaussians: Number of Gaussian basis functions used to expand each
            distance.
        cutoff: Upper end (``stop``) of the Gaussian basis range. When
            ``smooth_cutoff`` is enabled it is also the radius at which the
            filters reach zero.
        smooth_cutoff: If True, multiply every filter by the SchNet cosine
            envelope ``0.5 * (cos(pi * d / cutoff) + 1)``, which is forced to
            zero for ``d >= cutoff``. Messages then fade continuously to zero
            instead of switching off abruptly when a neighbour leaves the
            graph. Defaults to False, which reproduces the original behaviour
            exactly.

    Raises:
        ValueError: If ``smooth_cutoff`` is True and ``cutoff`` is not positive
            (the envelope divides by ``cutoff``).
    """

    def __init__(
        self,
        hidden_dim: int,
        num_gaussians: int = 50,
        cutoff: float = 5.0,
        *,
        smooth_cutoff: bool = False,
    ):
        super().__init__(aggr='add')
        if smooth_cutoff and cutoff <= 0:
            raise ValueError(
                f"cutoff must be positive when smooth_cutoff=True, got {cutoff}"
            )
        self.hidden_dim = hidden_dim
        self.cutoff = cutoff
        self.smooth_cutoff = smooth_cutoff

        # Gaussian RBF basis for distance expansion over [0, cutoff].
        self.distance_expansion = GaussianSmearing(
            start=0.0,
            stop=cutoff,
            num_gaussians=num_gaussians,
        )
        # Filter-generating network: RBF features -> per-edge filter weights.
        self.filter_network = nn.Sequential(
            nn.Linear(num_gaussians, hidden_dim),
            ShiftedSoftplus(),
            nn.Linear(hidden_dim, hidden_dim),
            ShiftedSoftplus(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Node-wise MLPs applied before and after message passing.
        self.node_mlp_1 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            ShiftedSoftplus(),
        )
        self.node_mlp_2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            ShiftedSoftplus(),
        )
        # Layer normalization applied after the residual sum, for stability.
        self.norm = nn.LayerNorm(hidden_dim)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform (gain 0.5) init for every Linear; biases zeroed."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.5)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_distances: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass with continuous-filter message passing.

        Args:
            x: Node features [num_nodes, hidden_dim].
            edge_index: Edge indices [2, num_edges].
            edge_attr: Edge features (e.g. spherical harmonics). Accepted for
                interface compatibility but NOT used by this layer.
            edge_distances: Edge distances, one value per edge.

        Returns:
            Updated node features [num_nodes, hidden_dim].
        """
        # Keep the raw input for the residual connection.
        x_input = x
        x = self.node_mlp_1(x)
        # Sum of filtered neighbour messages (aggr='add').
        x_out = self.propagate(
            edge_index,
            x=x,
            edge_distances=edge_distances,
        )
        x_out = self.node_mlp_2(x_out)
        return self.norm(x_input + x_out)

    def message(
        self,
        x_j: torch.Tensor,
        edge_distances: torch.Tensor,
    ) -> torch.Tensor:
        """
        Build one message per edge from the neighbour features and distance.

        Args:
            x_j: Neighbour (source-node) features [num_edges, hidden_dim].
            edge_distances: Edge distances, one value per edge.

        Returns:
            Messages [num_edges, hidden_dim].
        """
        edge_features = self.distance_expansion(edge_distances)  # [num_edges, num_gaussians]
        filters = self.filter_network(edge_features)  # [num_edges, hidden_dim]
        if self.smooth_cutoff:
            filters = filters * self._cosine_envelope(edge_distances)
        # Continuous-filter convolution: element-wise product.
        return x_j * filters

    def _cosine_envelope(self, edge_distances: torch.Tensor) -> torch.Tensor:
        """Return SchNet cosine-cutoff weights shaped [num_edges, 1]."""
        # Assumes exactly one distance per edge ([num_edges] or [num_edges, 1]).
        d = edge_distances.reshape(-1, 1)
        envelope = 0.5 * (torch.cos(d * (math.pi / self.cutoff)) + 1.0)
        # cos is periodic, so explicitly zero the weight beyond the cutoff.
        return envelope * (d < self.cutoff).to(envelope.dtype)
class InteractionBlock(nn.Module):
    """
    SchNet-style interaction block.

    Runs a ``CFConvMessagePassing`` layer, sends its output through a dense
    pathway (Linear -> ShiftedSoftplus -> Linear), and adds the result back
    onto the block's input as a residual update.

    Args:
        hidden_dim: Node feature width; input and output widths are equal.
        num_gaussians: Size of the Gaussian distance basis, forwarded to the
            CFConv layer.
        cutoff: Distance cutoff, forwarded to the CFConv layer.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_gaussians: int = 50,
        cutoff: float = 5.0,
    ):
        super().__init__()
        self.cfconv = CFConvMessagePassing(
            hidden_dim=hidden_dim,
            num_gaussians=num_gaussians,
            cutoff=cutoff,
        )
        # Dense pathway: two square projections with an activation in between.
        width = hidden_dim
        self.dense_1 = nn.Linear(width, width)
        self.dense_2 = nn.Linear(width, width)
        self.activation = ShiftedSoftplus()

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_distances: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply the interaction block.

        Args:
            x: Node features [num_nodes, hidden_dim].
            edge_index: Edge indices [2, num_edges].
            edge_attr: Edge features; forwarded to the CFConv layer, which
                does not use them.
            edge_distances: Edge distances [num_edges].

        Returns:
            ``x`` plus the dense-projected CFConv output, [num_nodes, hidden_dim].
        """
        convolved = self.cfconv(x, edge_index, edge_attr, edge_distances)
        hidden = self.activation(self.dense_1(convolved))
        update = self.dense_2(hidden)
        return x + update