import logging
from typing import Optional, Literal, Dict

import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor

from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.layers.attention import AttentionBlockSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool
from se3_transformer.model.fiber import Fiber


class Sequential(nn.Sequential):
    """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """

    def forward(self, input, *args, **kwargs):
        for module in self:
            input = module(input, *args, **kwargs)
        return input


def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
    """ Add the norm of the relative positions as an invariant (degree-0) edge feature """
    edge_features = edge_features.copy() if edge_features else {}
    r = relative_pos.norm(dim=-1, keepdim=True)
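    # The edge length is an SE(3)-invariant scalar, appended as an extra degree-0 channel:
    # r[..., None] has shape (num_edges, 1, 1), matching the (channels, 2*degree+1) feature layout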
    if '0' in edge_features:
        edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
    else:
        edge_features['0'] = r[..., None]

    return edge_features


class SE3Transformer(nn.Module):
    def __init__(self,
                 num_layers: int,
                 fiber_in: Fiber,
                 fiber_hidden: Fiber,
                 fiber_out: Fiber,
                 num_heads: int,
                 channels_div: int,
                 fiber_edge: Fiber = Fiber({}),
                 return_type: Optional[int] = None,
                 pooling: Optional[Literal['avg', 'max']] = None,
                 norm: bool = True,
                 use_layer_norm: bool = True,
                 tensor_cores: bool = False,
                 low_memory: bool = False,
                 **kwargs):
        """
        :param num_layers:     Number of attention layers
        :param fiber_in:       Input fiber description
        :param fiber_hidden:   Hidden fiber description
        :param fiber_out:      Output fiber description
        :param fiber_edge:     Input edge fiber description
        :param num_heads:      Number of attention heads
        :param channels_div:   Factor by which the number of channels is divided inside the attention blocks
        :param return_type:    Return only features of this type
        :param pooling:        'avg' or 'max' graph pooling before MLP layers
        :param norm:           Apply a normalization layer after each attention block
        :param use_layer_norm: Apply layer normalization between MLP layers
        :param tensor_cores:   True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
        :param low_memory:     If True, will use slower ops that use less memory
        """
        super().__init__()
        self.num_layers = num_layers
        self.fiber_edge = fiber_edge
        self.num_heads = num_heads
        self.channels_div = channels_div
        self.return_type = return_type
        self.pooling = pooling
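        # Bases (and padding) must cover the highest feature degree used anywhere in the model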
        self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
        self.tensor_cores = tensor_cores
        self.low_memory = low_memory

        if low_memory and not tensor_cores:
            logging.warning('Low memory mode will have no effect without Tensor Cores')
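
        # Fully fused convolutions are only used with Tensor Cores and when low-memory mode is off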
        fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL
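
        # Stack of attention blocks, each optionally followed by an equivariant norm layer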
        graph_modules = []
        for i in range(num_layers):
            graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
                                                   fiber_out=fiber_hidden,
                                                   fiber_edge=fiber_edge,
                                                   num_heads=num_heads,
                                                   channels_div=channels_div,
                                                   use_layer_norm=use_layer_norm,
                                                   max_degree=self.max_degree,
                                                   fuse_level=fuse_level))
            if norm:
                graph_modules.append(NormSE3(fiber_hidden))
            fiber_in = fiber_hidden
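
        # Final equivariant convolution projecting the hidden fiber onto the output fiber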
        graph_modules.append(ConvSE3(fiber_in=fiber_in,
                                     fiber_out=fiber_out,
                                     fiber_edge=fiber_edge,
                                     self_interaction=True,
                                     use_layer_norm=use_layer_norm,
                                     max_degree=self.max_degree))
        self.graph_modules = Sequential(*graph_modules)

        if pooling is not None:
            assert return_type is not None, 'return_type must be specified when pooling'
            self.pooling_module = GPooling(pool=pooling, feat_type=return_type)

    def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
                edge_feats: Optional[Dict[str, Tensor]] = None,
                basis: Optional[Dict[str, Tensor]] = None):
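        # Compute bases from the relative positions in case they were not precomputed by the data loader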
        basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
                                   use_pad_trick=self.tensor_cores and not self.low_memory,
                                   amp=torch.is_autocast_enabled())
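
        # Add fused basis variants (per input degree, per output degree, fully fused) to the basis dict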
        basis = update_basis_with_fused(basis, self.max_degree,
                                        use_pad_trick=self.tensor_cores and not self.low_memory,
                                        fully_fused=self.tensor_cores and not self.low_memory)
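
        # Append the invariant edge length as an extra degree-0 edge feature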
        edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)

        node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)

        if self.pooling is not None:
            return self.pooling_module(node_feats, graph=graph)

        if self.return_type is not None:
            return node_feats[str(self.return_type)]

        return node_feats

    @staticmethod
    def add_argparse_args(parser):
        parser.add_argument('--num_layers', type=int, default=7,
                            help='Number of stacked Transformer layers')
        parser.add_argument('--num_heads', type=int, default=8,
                            help='Number of heads in self-attention')
        parser.add_argument('--channels_div', type=int, default=2,
                            help='Factor by which the number of channels is divided inside the attention blocks')
        parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
                            help='Type of graph pooling')
        parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply a normalization layer after each attention block')
        parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply layer normalization between MLP layers')
        parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
                            help='If true, will use fused ops that are slower but use less memory '
                                 '(expect 25 percent less memory). '
                                 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')

        return parser
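

# Illustrative usage sketch, assuming a DGLGraph whose edata carries 'rel_pos' and type-0 node
# features passed under the key '0'; the Fiber sizes below are placeholders, not recommended values.
#
#   model = SE3Transformer(num_layers=7,
#                          fiber_in=Fiber.create(1, 32),
#                          fiber_hidden=Fiber.create(4, 32),
#                          fiber_out=Fiber.create(1, 32),
#                          num_heads=8,
#                          channels_div=2,
#                          return_type=0)
#   out = model(graph, node_feats={'0': node_scalars}, edge_feats=None)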


class SE3TransformerPooled(nn.Module):
    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Fiber,
                 num_degrees: int,
                 num_channels: int,
                 output_dim: int,
                 **kwargs):
        super().__init__()
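        # Default to max graph pooling when the caller does not specify a pooling mode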
        kwargs['pooling'] = kwargs['pooling'] or 'max'
        self.transformer = SE3Transformer(
            fiber_in=fiber_in,
            fiber_hidden=Fiber.create(num_degrees, num_channels),
            fiber_out=fiber_out,
            fiber_edge=fiber_edge,
            return_type=0,
            **kwargs
        )

        n_out_features = fiber_out.num_features
        self.mlp = nn.Sequential(
            nn.Linear(n_out_features, n_out_features),
            nn.ReLU(),
            nn.Linear(n_out_features, output_dim)
        )

    def forward(self, graph, node_feats, edge_feats, basis=None):
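        # squeeze(-1) drops trailing singleton dimensions: from the pooled features, and from the
        # MLP output so that output_dim == 1 yields per-graph predictions of shape (batch_size,)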
        feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
        y = self.mlp(feats).squeeze(-1)
        return y

    @staticmethod
    def add_argparse_args(parent_parser):
        parser = parent_parser.add_argument_group("Model architecture")
        SE3Transformer.add_argparse_args(parser)
        parser.add_argument('--num_degrees',
                            help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
                            type=int, default=4)
        parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
        return parent_parser