English
Shanci's picture
Upload folder using huggingface_hub
26225c5 verified
from torch import nn
from src.utils import listify_with_reference
from src.nn import Stage, PointStage, DownNFuseStage, UpNFuseStage, \
BatchNorm, CatFusion, MLP, LayerNorm
from src.nn.pool import BaseAttentivePool
from src.nn.pool import pool_factory
__all__ = ['SPT']
class SPT(nn.Module):
"""Superpoint Transformer. A UNet-like architecture processing NAG.
The architecture can be (roughly) summarized as:
p_0, x_0 --------- PointStage
\
p_1, x_1, e_1 -- DownNFuseStage_1 ------- UpNFuseStage_1 --> out_1
\ |
p_2, x_2, e_2 -- DownNFuseStage_2 ------- UpNFuseStage_2 --> out_2
\ |
... ...
Where:
- p_0: point positions
- x_0: input point features (handcrafted)
- p_i: node positions (i.e. superpoint centroid) at level i
- x_i: input node features (handcrafted superpoint features) at
level i
- e_i: input edge features (handcrafted horizontal superpoint graph
edge features) at level i
- out_i: node-wise output features at level i
:param point_mlp: List[int]
Channels for the input MLP of the `PointStage`
:param point_drop: float
Dropout rate for the last layer of the input and output MLPs
in `PointStage`
:param nano: bool
If True, the `PointStage` will be removed and the model will
only operate on superpoints, without extracting features
from the points. This lightweight model saves compute and
memory, at the potential expense of high-resolution
reasoning
:param down_dim: List[int], int
Feature dimension for each `DownNFuseStage` (i.e. not
including the `PointStage` when `nano=False`)
:param down_pool_dim: List[str], str
Pooling mechanism used for the down-pooling in each
`DownNFuseStage`. Supports 'max', 'min', 'mean', and 'sum'.
See `pool_factory()` for more
:param down_in_mlp: List[List[int]], List[int]
Channels for the input MLP of each `DownNFuseStage`
:param down_out_mlp: List[List[int]], List[int]
Channels for the output MLP of each `DownNFuseStage`. The
first channel for each stage must match with what is passed
in `down_dim`
:param down_mlp_drop: List[float], float
Dropout rate for the last layer of the input and output MLPs
of each `DownNFuseStage`
:param down_num_heads: List[int], int
Number of self-attention heads for each `DownNFuseStage
:param down_num_blocks: List[int], int
Number of self-attention blocks for each `DownNFuseStage
:param down_ffn_ratio: List[float], float
Multiplicative factor for computing the dimension of the
`FFN` inverted bottleneck, for each `DownNFuseStage. See
`TransformerBlock`
:param down_residual_drop: List[float], float
Dropout on the output self-attention features for each
`DownNFuseStage`. See `TransformerBlock`
:param down_attn_drop: List[float], float
Dropout on the self-attention weights for each
`DownNFuseStage`. See `TransformerBlock`
:param down_drop_path: List[float], float
Dropout on the residual paths for each `DownNFuseStage`. See
`TransformerBlock`
:param up_dim: List[int], int
Feature dimension for each `UpNFuseStage`
:param up_in_mlp: List[List[int]], List[int]
Channels for the input MLP of each `UpNFuseStage`
:param up_out_mlp: List[List[int]], List[int]
Channels for the output MLP of each `UpNFuseStage`. The
first channel for each stage must match with what is passed
in `up_dim`
:param up_mlp_drop: List[float], float
Dropout rate for the last layer of the input and output MLPs
of each `UpNFuseStage`
:param up_num_heads: List[int], int
Number of self-attention heads for each `UpNFuseStage
:param up_num_blocks: List[int], int
Number of self-attention blocks for each `UpNFuseStage
:param up_ffn_ratio: List[float], float
Multiplicative factor for computing the dimension of the
`FFN` inverted bottleneck, for each `UpNFuseStage. See
`TransformerBlock`
:param up_residual_drop: List[float], float
Dropout on the output self-attention features for each
`UpNFuseStage`. See `TransformerBlock`
:param up_attn_drop: List[float], float
Dropout on the self-attention weights for each
`UpNFuseStage`. See `TransformerBlock`
:param up_drop_path: List[float], float
Dropout on the residual paths for each `UpNFuseStage`. See
`TransformerBlock`
:param node_mlp: List[int]
Channels for the MLPs that will encode handcrafted node
(i.e. segment, superpoint) features. These will be called
before each `DownNFuseStage` and their output will be
concatenated to any already-existing features and passed
to `DownNFuseStage` and `UpNFuseStage`. For the special case
the `nano=True` model, the first MLP will be run before the
first `Stage` too
:param h_edge_mlp: List[int]
Channels for the MLPs that will encode handcrafted
horizontal edge (i.e. edges in the superpoint adjacency
graph at each partition level) features. These will be
called before each `DownNFuseStage` and their output will be
passed as `edge_attr` to `DownNFuseStage` and `UpNFuseStage`
:param v_edge_mlp: List[int]
Channels for the MLPs that will encode handcrafted
vertical edge (i.e. edges connecting nodes to their parent
in the above partition level) features. These will be
called before each `DownNFuseStage` and their output will be
passed as `v_edge_attr` to `DownNFuseStage` and
`UpNFuseStage`
:param mlp_activation: nn.Module
Activation function used for all MLPs throughout the
architecture
:param mlp_norm: n.Module
Normalization function for all MLPs throughout the
architecture
:param qk_dim: int
Dimension of the queries and keys. See `SelfAttentionBlock`
:param qkv_bias: bool
Whether the linear layers producing queries, keys, and
values should have a bias. See `SelfAttentionBlock`
:param qk_scale: str
Scaling applied to the query*key product before the softmax.
More specifically, one may want to normalize the query-key
compatibilities based on the number of dimensions (referred
to as 'd' here) as in a vanilla Transformer implementation,
or based on the number of neighbors each node has in the
attention graph (referred to as 'g' here). If nothing is
specified the scaling will be `1 / (sqrt(d) * sqrt(g))`,
which is equivalent to passing `'d.g'`. Passing `'d+g'` will
yield `1 / (sqrt(d) + sqrt(g))`. Meanwhile, passing 'd' will
yield `1 / sqrt(d)`, and passing `'g'` will yield
`1 / sqrt(g)`. See `SelfAttentionBlock`
:param in_rpe_dim:
:param activation: nn.Module
Activation function used in the `FFN` modules. See
`TransformerBlock`
:param norm: nn.Module
Normalization function for the `FFN` module. See
`TransformerBlock`
:param pre_norm: bool
Whether the normalization should be applied before or after
the `SelfAttentionBlock` and `FFN` in the residual branches.
See`TransformerBlock`
:param no_sa: bool
Whether a self-attention residual branch should be used at
all. See`TransformerBlock`
:param no_ffn: bool
Whether a feed-forward residual branch should be used at
all. See`TransformerBlock`
:param k_rpe: bool
Whether keys should receive relative positional encodings
computed from edge features. See `SelfAttentionBlock`
:param q_rpe: bool
Whether queries should receive relative positional encodings
computed from edge features. See `SelfAttentionBlock`
:param v_rpe: bool
Whether values should receive relative positional encodings
computed from edge features. See `SelfAttentionBlock`
:param k_delta_rpe: bool
Whether keys should receive relative positional encodings
computed from the difference between source and target node
features. See `SelfAttentionBlock`
:param q_delta_rpe: bool
Whether queries should receive relative positional encodings
computed from the difference between source and target node
features. See `SelfAttentionBlock`
:param qk_share_rpe: bool
Whether queries and keys should use the same parameters for
building relative positional encodings. See
`SelfAttentionBlock`
:param q_on_minus_rpe: bool
Whether relative positional encodings for queries should be
computed on the opposite of features used for keys. This allows,
for instance, to break the symmetry when `qk_share_rpe` but we
want relative positional encodings to capture different meanings
for keys and queries. See `SelfAttentionBlock`
:param share_hf_mlps: bool
Whether stages should share the MLPs for encoding
handcrafted node, horizontal edge, and vertical edge
features
:param stages_share_rpe: bool
Whether all `Stage`s should share the same parameters for
building relative positional encodings
:param blocks_share_rpe: bool
Whether all the `TransformerBlock` in the same `Stage`
should share the same parameters for building relative
positional encodings
:param heads_share_rpe: bool
Whether attention heads should share the same parameters for
building relative positional encodings
:param use_pos: bool
Whether the node's position (normalized with `UnitSphereNorm`)
should be concatenated to the features. See `Stage`
:param use_node_hf: bool
Whether handcrafted node (i.e. segment, superpoint) features
should be used at all. If False, `node_mlp` will be ignored
:param use_diameter: bool
Whether the node's diameter (currently estimated with
`UnitSphereNorm`) should be concatenated to the node features.
See `Stage`
:param use_diameter_parent: bool
Whether the node's parent diameter (currently estimated with
`UnitSphereNorm`) should be concatenated to the node features.
See `Stage`
:param pool: str, nn.Module
Pooling mechanism for `DownNFuseStage`s. Supports 'max',
'min', 'mean', 'sum' for string arguments.
See `pool_factory()` for more
:param unpool: str
Unpooling mechanism for `UpNFuseStage`s. Only supports
'index' for now
:param fusion: str
Fusion mechanism used in `DownNFuseStage` and `UpNFuseStage`
to merge node features from different branches. Supports
'cat', 'residual', 'first', 'second'. See `fusion_factory()`
for more
:param norm_mode: str
Indexing mode used for feature normalization. This will be
passed to `Data.norm_index()`. 'graph' will normalize
features per graph (i.e. per cloud, i.e. per batch item).
'node' will normalize per node (i.e. per point). 'segment'
will normalize per segment (i.e. per cluster)
:param output_stage_wise: bool
If True, the output contain the features for each node of
each partition 1+ level. IF False, only the features for the
partition level 1 will be returned. Note we do not compute
the features for level 0, since the entire goal of this
superpoint-based reasoning is to mitigate compute and memory
by circumventing the need to manipulate such full-resolution
objects
"""
def __init__(
self,
point_mlp=None,
point_drop=None,
nano=False,
down_dim=None,
down_pool_dim=None,
down_in_mlp=None,
down_out_mlp=None,
down_mlp_drop=None,
down_num_heads=1,
down_num_blocks=0,
down_ffn_ratio=4,
down_residual_drop=None,
down_attn_drop=None,
down_drop_path=None,
up_dim=None,
up_in_mlp=None,
up_out_mlp=None,
up_mlp_drop=None,
up_num_heads=1,
up_num_blocks=0,
up_ffn_ratio=4,
up_residual_drop=None,
up_attn_drop=None,
up_drop_path=None,
node_mlp=None,
h_edge_mlp=None,
v_edge_mlp=None,
mlp_activation=nn.LeakyReLU(),
mlp_norm=BatchNorm,
qk_dim=8,
qkv_bias=True,
qk_scale=None,
in_rpe_dim=18,
activation=nn.LeakyReLU(),
norm=LayerNorm,
pre_norm=True,
no_sa=False,
no_ffn=False,
k_rpe=False,
q_rpe=False,
v_rpe=False,
k_delta_rpe=False,
q_delta_rpe=False,
qk_share_rpe=False,
q_on_minus_rpe=False,
share_hf_mlps=False,
stages_share_rpe=False,
blocks_share_rpe=False,
heads_share_rpe=False,
use_pos=True,
use_node_hf=True,
use_diameter=False,
use_diameter_parent=False,
pool='max',
unpool='index',
fusion='cat',
norm_mode='graph',
output_stage_wise=False):
super().__init__()
self.nano = nano
self.use_pos = use_pos
self.use_node_hf = use_node_hf
self.use_diameter = use_diameter
self.use_diameter_parent = use_diameter_parent
self.norm_mode = norm_mode
self.stages_share_rpe = stages_share_rpe
self.blocks_share_rpe = blocks_share_rpe
self.heads_share_rpe = heads_share_rpe
self.output_stage_wise = output_stage_wise
# Convert input arguments to nested lists
(
down_dim,
down_pool_dim,
down_in_mlp,
down_out_mlp,
down_mlp_drop,
down_num_heads,
down_num_blocks,
down_ffn_ratio,
down_residual_drop,
down_attn_drop,
down_drop_path
) = listify_with_reference(
down_dim,
down_pool_dim,
down_in_mlp,
down_out_mlp,
down_mlp_drop,
down_num_heads,
down_num_blocks,
down_ffn_ratio,
down_residual_drop,
down_attn_drop,
down_drop_path)
(
up_dim,
up_in_mlp,
up_out_mlp,
up_mlp_drop,
up_num_heads,
up_num_blocks,
up_ffn_ratio,
up_residual_drop,
up_attn_drop,
up_drop_path
) = listify_with_reference(
up_dim,
up_in_mlp,
up_out_mlp,
up_mlp_drop,
up_num_heads,
up_num_blocks,
up_ffn_ratio,
up_residual_drop,
up_attn_drop,
up_drop_path)
# Local helper variables describing the architecture
num_down = len(down_dim) - self.nano
num_up = len(up_dim)
needs_h_edge_hf = any(x > 0 for x in down_num_blocks + up_num_blocks)
needs_v_edge_hf = num_down > 0 and isinstance(
pool_factory(pool, down_pool_dim[0]), BaseAttentivePool)
# Build MLPs that will be used to process handcrafted segment
# and edge features. These will be called before each
# DownNFuseStage and their output will be passed to
# DownNFuseStage and UpNFuseStage. For the special case of nano
# models, the first mlps will be run before the first Stage too
node_mlp = node_mlp if use_node_hf else None
self.node_mlps = _build_mlps(
node_mlp,
num_down + self.nano,
mlp_activation,
mlp_norm,
share_hf_mlps)
h_edge_mlp = h_edge_mlp if needs_h_edge_hf else None
self.h_edge_mlps = _build_mlps(
h_edge_mlp,
num_down + self.nano,
mlp_activation,
mlp_norm,
share_hf_mlps)
v_edge_mlp = v_edge_mlp if needs_v_edge_hf else None
self.v_edge_mlps = _build_mlps(
v_edge_mlp,
num_down,
mlp_activation,
mlp_norm,
share_hf_mlps)
# Module operating on Level-0 points in isolation
if self.nano:
self.first_stage = Stage(
down_dim[0],
num_blocks=down_num_blocks[0],
in_mlp=down_in_mlp[0],
out_mlp=down_out_mlp[0],
mlp_activation=mlp_activation,
mlp_norm=mlp_norm,
mlp_drop=down_mlp_drop[0],
num_heads=down_num_heads[0],
qk_dim=qk_dim,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
in_rpe_dim=in_rpe_dim,
ffn_ratio=down_ffn_ratio[0],
residual_drop=down_residual_drop[0],
attn_drop=down_attn_drop[0],
drop_path=down_drop_path[0],
activation=activation,
norm=norm,
pre_norm=pre_norm,
no_sa=no_sa,
no_ffn=no_ffn,
k_rpe=k_rpe,
q_rpe=q_rpe,
v_rpe=v_rpe,
k_delta_rpe=k_delta_rpe,
q_delta_rpe=q_delta_rpe,
qk_share_rpe=qk_share_rpe,
q_on_minus_rpe=q_on_minus_rpe,
use_pos=use_pos,
use_diameter=use_diameter,
use_diameter_parent=use_diameter_parent,
blocks_share_rpe=blocks_share_rpe,
heads_share_rpe=heads_share_rpe)
else:
self.first_stage = PointStage(
point_mlp,
mlp_activation=mlp_activation,
mlp_norm=mlp_norm,
mlp_drop=point_drop,
use_pos=use_pos,
use_diameter_parent=use_diameter_parent)
# Operator to append the features such as the diameter or other
# handcrafted features to the NAG's features
self.feature_fusion = CatFusion()
# Transformer encoder (down) Stages operating on Level-i data
if num_down > 0:
# Build the RPE encoders here if shared across all stages
down_k_rpe = _build_shared_rpe_encoders(
k_rpe, num_down, 18, qk_dim, stages_share_rpe)
# If key and query RPEs share the same MLP, only the key MLP
# is preserved, to limit the number of model parameters
down_q_rpe = _build_shared_rpe_encoders(
q_rpe and not (k_rpe and qk_share_rpe), num_down, 18, qk_dim,
stages_share_rpe)
# Since the first value of each down_ parameter is used for
# the nano Stage (if self.nano=True), we artificially
# prepend None values to the rpe lists, so they have the
# same length as other down_ parameters
if self.nano:
down_k_rpe = [None] + down_k_rpe
down_q_rpe = [None] + down_q_rpe
self.down_stages = nn.ModuleList([
DownNFuseStage(
dim,
num_blocks=num_blocks,
in_mlp=in_mlp,
out_mlp=out_mlp,
mlp_activation=mlp_activation,
mlp_norm=mlp_norm,
mlp_drop=mlp_drop,
num_heads=num_heads,
qk_dim=qk_dim,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
in_rpe_dim=in_rpe_dim,
ffn_ratio=ffn_ratio,
residual_drop=residual_drop,
attn_drop=attn_drop,
drop_path=drop_path,
activation=activation,
norm=norm,
pre_norm=pre_norm,
no_sa=no_sa,
no_ffn=no_ffn,
k_rpe=stage_k_rpe,
q_rpe=stage_q_rpe,
v_rpe=v_rpe,
k_delta_rpe=k_delta_rpe,
q_delta_rpe=q_delta_rpe,
qk_share_rpe=qk_share_rpe,
q_on_minus_rpe=q_on_minus_rpe,
pool=pool_factory(pool, pool_dim),
fusion=fusion,
use_pos=use_pos,
use_diameter=use_diameter,
use_diameter_parent=use_diameter_parent,
blocks_share_rpe=blocks_share_rpe,
heads_share_rpe=heads_share_rpe)
for
i_down,
(dim,
num_blocks,
in_mlp,
out_mlp,
mlp_drop,
num_heads,
ffn_ratio,
residual_drop,
attn_drop,
drop_path,
stage_k_rpe,
stage_q_rpe,
pool_dim)
in enumerate(zip(
down_dim,
down_num_blocks,
down_in_mlp,
down_out_mlp,
down_mlp_drop,
down_num_heads,
down_ffn_ratio,
down_residual_drop,
down_attn_drop,
down_drop_path,
down_k_rpe,
down_q_rpe,
down_pool_dim))
if i_down >= self.nano])
else:
self.down_stages = None
# Transformer decoder (up) Stages operating on Level-i data
if num_up > 0:
# Build the RPE encoder here if shared across all stages
up_k_rpe = _build_shared_rpe_encoders(
k_rpe, num_up, 18, qk_dim, stages_share_rpe)
# If key and query RPEs share the same MLP, only the key MLP
# is preserved, to limit the number of model parameters
up_q_rpe = _build_shared_rpe_encoders(
q_rpe and not (k_rpe and qk_share_rpe), num_up, 18, qk_dim,
stages_share_rpe)
self.up_stages = nn.ModuleList([
UpNFuseStage(
dim,
num_blocks=num_blocks,
in_mlp=in_mlp,
out_mlp=out_mlp,
mlp_activation=mlp_activation,
mlp_norm=mlp_norm,
mlp_drop=mlp_drop,
num_heads=num_heads,
qk_dim=qk_dim,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
in_rpe_dim=in_rpe_dim,
ffn_ratio=ffn_ratio,
residual_drop=residual_drop,
attn_drop=attn_drop,
drop_path=drop_path,
activation=activation,
norm=norm,
pre_norm=pre_norm,
no_sa=no_sa,
no_ffn=no_ffn,
k_rpe=stage_k_rpe,
q_rpe=stage_q_rpe,
v_rpe=v_rpe,
k_delta_rpe=k_delta_rpe,
q_delta_rpe=q_delta_rpe,
qk_share_rpe=qk_share_rpe,
q_on_minus_rpe=q_on_minus_rpe,
unpool=unpool,
fusion=fusion,
use_pos=use_pos,
use_diameter=use_diameter,
use_diameter_parent=use_diameter_parent,
blocks_share_rpe=blocks_share_rpe,
heads_share_rpe=heads_share_rpe)
for dim,
num_blocks,
in_mlp,
out_mlp,
mlp_drop,
num_heads,
ffn_ratio,
residual_drop,
attn_drop,
drop_path,
stage_k_rpe,
stage_q_rpe
in zip(
up_dim,
up_num_blocks,
up_in_mlp,
up_out_mlp,
up_mlp_drop,
up_num_heads,
up_ffn_ratio,
up_residual_drop,
up_attn_drop,
up_drop_path,
up_k_rpe,
up_q_rpe)])
else:
self.up_stages = None
assert self.num_up_stages > 0 or not self.output_stage_wise, \
"At least one up stage is needed for output_stage_wise=True"
assert bool(self.down_stages) != bool(self.up_stages) \
or self.num_down_stages >= self.num_up_stages, \
"The number of Up stages should be <= the number of Down " \
"stages."
assert self.nano or self.num_down_stages > self.num_up_stages, \
"The number of Up stages should be < the number of Down " \
"stages. That is to say, we do not want to output Level-0 " \
"features but at least Level-1."
@property
def num_down_stages(self):
return len(self.down_stages) if self.down_stages is not None else 0
@property
def num_up_stages(self):
return len(self.up_stages) if self.up_stages is not None else 0
@property
def out_dim(self):
if self.output_stage_wise:
out_dim = [stage.out_dim for stage in self.up_stages][::-1]
out_dim += [self.down_stages[-1].out_dim]
return out_dim
if self.up_stages is not None:
return self.up_stages[-1].out_dim
if self.down_stages is not None:
return self.down_stages[-1].out_dim
return self.first_stage.out_dim
def forward(self, nag):
# assert isinstance(nag, NAG)
# assert nag.num_levels >= 2
# assert nag.num_levels > self.num_down_stages
# TODO: this will need to be changed if we want FAST NANO
if self.nano:
nag = nag[1:]
# Apply the first MLPs on the handcrafted features
if self.nano:
if self.node_mlps is not None and self.node_mlps[0] is not None:
norm_index = nag[0].norm_index(mode=self.norm_mode)
nag[0].x = self.node_mlps[0](nag[0].x, batch=norm_index)
if self.h_edge_mlps is not None:
norm_index = nag[0].norm_index(mode=self.norm_mode)
norm_index = norm_index[nag[0].edge_index[0]]
nag[0].edge_attr = self.h_edge_mlps[0](
nag[0].edge_attr, batch=norm_index)
# Encode level-0 data
x, diameter = self.first_stage(
nag[0].x if self.use_node_hf else None,
nag[0].norm_index(mode=self.norm_mode),
pos=nag[0].pos,
diameter=None,
node_size=getattr(nag[0], 'node_size', None),
super_index=nag[0].super_index,
edge_index=nag[0].edge_index,
edge_attr=nag[0].edge_attr)
# Add the diameter to the next level's attributes
nag[1].diameter = diameter
# Iteratively encode level-1 and above
down_outputs = []
if self.nano:
down_outputs.append(x)
if self.down_stages is not None:
enum = enumerate(zip(
self.down_stages,
self.node_mlps[int(self.nano):],
self.h_edge_mlps[int(self.nano):],
self.v_edge_mlps))
for i_stage, (stage, node_mlp, h_edge_mlp, v_edge_mlp) in enum:
# Forward on the down stage and the corresponding NAG
# level
i_level = i_stage + 1
# Process handcrafted node and edge features. We need to
# do this here before those can be passed to the
# DownNFuseStage and, later on, to the UpNFuseStage
if node_mlp is not None:
norm_index = nag[i_level].norm_index(mode=self.norm_mode)
nag[i_level].x = node_mlp(nag[i_level].x, batch=norm_index)
if h_edge_mlp is not None:
norm_index = nag[i_level].norm_index(mode=self.norm_mode)
norm_index = norm_index[nag[i_level].edge_index[0]]
edge_attr = getattr(nag[i_level], 'edge_attr', None)
if edge_attr is not None:
nag[i_level].edge_attr = h_edge_mlp(
edge_attr, batch=norm_index)
if v_edge_mlp is not None:
norm_index = nag[i_level - 1].norm_index(mode=self.norm_mode)
v_edge_attr = getattr(nag[i_level], 'v_edge_attr', None)
if v_edge_attr is not None:
nag[i_level - 1].v_edge_attr = v_edge_mlp(
v_edge_attr, batch=norm_index)
# Forward on the DownNFuseStage
x, diameter = self._forward_down_stage(stage, nag, i_level, x)
down_outputs.append(x)
# End here if we reached the last NAG level
if i_level == nag.num_levels - 1:
continue
# Add the diameter to the next level's attributes
nag[i_level + 1].diameter = diameter
# Iteratively decode level-num_down_stages and below
up_outputs = []
if self.up_stages is not None:
for i_stage, stage in enumerate(self.up_stages):
i_level = self.num_down_stages - i_stage - 1
x_skip = down_outputs[-(2 + i_stage)]
x, _ = self._forward_up_stage(stage, nag, i_level, x, x_skip)
up_outputs.append(x)
# Different types of output signatures. For stage-wise output,
# return the output for each stage. For the Lmax level, we take
# the output of the innermost 'down_stage'. Finally, these
# outputs are sorted by order of increasing NAG level (from low
# to high)
if self.output_stage_wise:
out = [x] + up_outputs[::-1][1:] + [down_outputs[-1]]
return out
return x
def _forward_down_stage(self, stage, nag, i_level, x):
is_last_level = (i_level == nag.num_levels - 1)
x_handcrafted = nag[i_level].x if self.use_node_hf else None
return stage(
x_handcrafted,
x,
nag[i_level].norm_index(mode=self.norm_mode),
nag[i_level - 1].super_index,
pos=nag[i_level].pos,
diameter=nag[i_level].diameter,
node_size=nag[i_level].node_size,
super_index=nag[i_level].super_index if not is_last_level else None,
edge_index=nag[i_level].edge_index,
edge_attr=nag[i_level].edge_attr,
v_edge_attr=nag[i_level - 1].v_edge_attr,
num_super=nag[i_level].num_nodes)
def _forward_up_stage(self, stage, nag, i_level, x, x_skip):
x_handcrafted = nag[i_level].x if self.use_node_hf else None
return stage(
self.feature_fusion(x_skip, x_handcrafted),
x,
nag[i_level].norm_index(mode=self.norm_mode),
nag[i_level].super_index,
pos=nag[i_level].pos,
diameter=nag[i_level - self.nano].diameter,
node_size=nag[i_level].node_size,
super_index=nag[i_level].super_index,
edge_index=nag[i_level].edge_index,
edge_attr=nag[i_level].edge_attr)
def _build_shared_rpe_encoders(
rpe, num_stages, in_dim, out_dim, stages_share):
"""Local helper to build RPE encoders for spt. The main goal is to
make shared encoders construction easier.
Note that setting stages_share=True will make all stages, blocks and
heads use the same RPE encoder.
"""
if not isinstance(rpe, bool):
assert stages_share, \
"If anything else but a boolean is passed for the RPE encoder, " \
"this value will be passed to all Stages and `stages_share` " \
"should be set to True."
return [rpe] * num_stages
# If all stages share the same RPE encoder, all blocks and all heads
# too. We copy the same module instance to be shared across all
# stages and blocks
if stages_share and rpe:
return [nn.Linear(in_dim, out_dim)] * num_stages
return [rpe] * num_stages
def _build_mlps(layers, num_stage, activation, norm, shared):
if layers is None:
return [None] * num_stage
if shared:
return nn.ModuleList([
MLP(layers, activation=activation, norm=norm)] * num_stage)
return nn.ModuleList([
MLP(layers, activation=activation, norm=norm)
for _ in range(num_stage)])