from torch import nn

from src.utils import listify_with_reference
from src.nn import Stage, PointStage, DownNFuseStage, UpNFuseStage, \
    BatchNorm, CatFusion, MLP, LayerNorm
from src.nn.pool import BaseAttentivePool, pool_factory


__all__ = ['SPT']


class SPT(nn.Module):
    """Superpoint Transformer. A UNet-like architecture processing NAG.

    The architecture can be (roughly) summarized as:

    p_0, x_0 --------- PointStage
                           \
    p_1, x_1, e_1 -- DownNFuseStage_1 ------- UpNFuseStage_1 --> out_1
                           \                        |
    p_2, x_2, e_2 -- DownNFuseStage_2 ------- UpNFuseStage_2 --> out_2
                           \                        |
                          ...                      ...

    Where:
    - p_0: point positions
    - x_0: input point features (handcrafted)
    - p_i: node positions (i.e. superpoint centroid) at level i
    - x_i: input node features (handcrafted superpoint features) at
      level i
    - e_i: input edge features (handcrafted horizontal superpoint graph
      edge features) at level i
    - out_i: node-wise output features at level i

    :param point_mlp: List[int]
        Channels for the input MLP of the `PointStage`
    :param point_drop: float
        Dropout rate for the last layer of the input and output MLPs
        in `PointStage`

    :param nano: bool
        If True, the `PointStage` will be removed and the model will
        only operate on superpoints, without extracting features
        from the points. This lightweight model saves compute and
        memory, at the potential expense of high-resolution
        reasoning

    :param down_dim: List[int], int
        Feature dimension for each `DownNFuseStage` (i.e. not
        including the `PointStage` when `nano=False`)
    :param down_pool_dim: List[int], int
        Dimension used by the pooling mechanism of each
        `DownNFuseStage`, when the pooling is learnable (e.g.
        attentive pooling). See `pool_factory()` for more
    :param down_in_mlp: List[List[int]], List[int]
        Channels for the input MLP of each `DownNFuseStage`
    :param down_out_mlp: List[List[int]], List[int]
        Channels for the output MLP of each `DownNFuseStage`. The
        first channel for each stage must match with what is passed
        in `down_dim`
    :param down_mlp_drop: List[float], float
        Dropout rate for the last layer of the input and output MLPs
        of each `DownNFuseStage`
    :param down_num_heads: List[int], int
        Number of self-attention heads for each `DownNFuseStage`
    :param down_num_blocks: List[int], int
        Number of self-attention blocks for each `DownNFuseStage`
    :param down_ffn_ratio: List[float], float
        Multiplicative factor for computing the dimension of the
        `FFN` inverted bottleneck, for each `DownNFuseStage`. See
        `TransformerBlock`
    :param down_residual_drop: List[float], float
        Dropout on the output self-attention features for each
        `DownNFuseStage`. See `TransformerBlock`
    :param down_attn_drop: List[float], float
        Dropout on the self-attention weights for each
        `DownNFuseStage`. See `TransformerBlock`
    :param down_drop_path: List[float], float
        Dropout on the residual paths for each `DownNFuseStage`. See
        `TransformerBlock`

    :param up_dim: List[int], int
        Feature dimension for each `UpNFuseStage`
    :param up_in_mlp: List[List[int]], List[int]
        Channels for the input MLP of each `UpNFuseStage`
    :param up_out_mlp: List[List[int]], List[int]
        Channels for the output MLP of each `UpNFuseStage`. The
        first channel for each stage must match with what is passed
        in `up_dim`
    :param up_mlp_drop: List[float], float
        Dropout rate for the last layer of the input and output MLPs
        of each `UpNFuseStage`
    :param up_num_heads: List[int], int
        Number of self-attention heads for each `UpNFuseStage`
    :param up_num_blocks: List[int], int
        Number of self-attention blocks for each `UpNFuseStage`
    :param up_ffn_ratio: List[float], float
        Multiplicative factor for computing the dimension of the
        `FFN` inverted bottleneck, for each `UpNFuseStage`. See
        `TransformerBlock`
    :param up_residual_drop: List[float], float
        Dropout on the output self-attention features for each
        `UpNFuseStage`. See `TransformerBlock`
    :param up_attn_drop: List[float], float
        Dropout on the self-attention weights for each
        `UpNFuseStage`. See `TransformerBlock`
    :param up_drop_path: List[float], float
        Dropout on the residual paths for each `UpNFuseStage`. See
        `TransformerBlock`

    :param node_mlp: List[int]
        Channels for the MLPs that will encode handcrafted node
        (i.e. segment, superpoint) features. These will be called
        before each `DownNFuseStage` and their output will be
        concatenated to any already-existing features and passed
        to `DownNFuseStage` and `UpNFuseStage`. For the special case
        of the `nano=True` model, the first MLP will be run before
        the first `Stage` too
    :param h_edge_mlp: List[int]
        Channels for the MLPs that will encode handcrafted
        horizontal edge (i.e. edges in the superpoint adjacency
        graph at each partition level) features. These will be
        called before each `DownNFuseStage` and their output will be
        passed as `edge_attr` to `DownNFuseStage` and `UpNFuseStage`
    :param v_edge_mlp: List[int]
        Channels for the MLPs that will encode handcrafted
        vertical edge (i.e. edges connecting nodes to their parent
        in the above partition level) features. These will be
        called before each `DownNFuseStage` and their output will be
        passed as `v_edge_attr` to `DownNFuseStage` and
        `UpNFuseStage`
    :param mlp_activation: nn.Module
        Activation function used for all MLPs throughout the
        architecture
    :param mlp_norm: nn.Module
        Normalization function for all MLPs throughout the
        architecture
    :param qk_dim: int
        Dimension of the queries and keys. See `SelfAttentionBlock`
    :param qkv_bias: bool
        Whether the linear layers producing queries, keys, and
        values should have a bias. See `SelfAttentionBlock`
    :param qk_scale: str
        Scaling applied to the query-key product before the softmax.
        More specifically, one may want to normalize the query-key
        compatibilities based on the number of dimensions (referred
        to as 'd' here) as in a vanilla Transformer implementation,
        or based on the number of neighbors each node has in the
        attention graph (referred to as 'g' here). If nothing is
        specified the scaling will be `1 / (sqrt(d) * sqrt(g))`,
        which is equivalent to passing `'d.g'`. Passing `'d+g'` will
        yield `1 / (sqrt(d) + sqrt(g))`. Meanwhile, passing `'d'`
        will yield `1 / sqrt(d)`, and passing `'g'` will yield
        `1 / sqrt(g)`. See `SelfAttentionBlock`
    :param in_rpe_dim: int
        Dimension of the features used as input for building the
        relative positional encodings. See `SelfAttentionBlock`
    :param activation: nn.Module
        Activation function used in the `FFN` modules. See
        `TransformerBlock`
    :param norm: nn.Module
        Normalization function for the `FFN` module. See
        `TransformerBlock`
    :param pre_norm: bool
        Whether the normalization should be applied before or after
        the `SelfAttentionBlock` and `FFN` in the residual branches.
        See `TransformerBlock`
    :param no_sa: bool
        Whether a self-attention residual branch should be used at
        all. See `TransformerBlock`
    :param no_ffn: bool
        Whether a feed-forward residual branch should be used at
        all. See `TransformerBlock`
    :param k_rpe: bool
        Whether keys should receive relative positional encodings
        computed from edge features. See `SelfAttentionBlock`
    :param q_rpe: bool
        Whether queries should receive relative positional encodings
        computed from edge features. See `SelfAttentionBlock`
    :param v_rpe: bool
        Whether values should receive relative positional encodings
        computed from edge features. See `SelfAttentionBlock`
    :param k_delta_rpe: bool
        Whether keys should receive relative positional encodings
        computed from the difference between source and target node
        features. See `SelfAttentionBlock`
    :param q_delta_rpe: bool
        Whether queries should receive relative positional encodings
        computed from the difference between source and target node
        features. See `SelfAttentionBlock`
    :param qk_share_rpe: bool
        Whether queries and keys should use the same parameters for
        building relative positional encodings. See
        `SelfAttentionBlock`
    :param q_on_minus_rpe: bool
        Whether relative positional encodings for queries should be
        computed on the opposite of the features used for keys. This
        allows, for instance, breaking the symmetry when
        `qk_share_rpe` is used, while still letting relative
        positional encodings capture different meanings for keys and
        queries. See `SelfAttentionBlock`
    :param share_hf_mlps: bool
        Whether stages should share the MLPs for encoding
        handcrafted node, horizontal edge, and vertical edge
        features
    :param stages_share_rpe: bool
        Whether all `Stage`s should share the same parameters for
        building relative positional encodings
    :param blocks_share_rpe: bool
        Whether all the `TransformerBlock`s in the same `Stage`
        should share the same parameters for building relative
        positional encodings
    :param heads_share_rpe: bool
        Whether attention heads should share the same parameters for
        building relative positional encodings

    :param use_pos: bool
        Whether the node's position (normalized with `UnitSphereNorm`)
        should be concatenated to the features. See `Stage`
    :param use_node_hf: bool
        Whether handcrafted node (i.e. segment, superpoint) features
        should be used at all. If False, `node_mlp` will be ignored
    :param use_diameter: bool
        Whether the node's diameter (currently estimated with
        `UnitSphereNorm`) should be concatenated to the node features.
        See `Stage`
    :param use_diameter_parent: bool
        Whether the node's parent diameter (currently estimated with
        `UnitSphereNorm`) should be concatenated to the node features.
        See `Stage`
    :param pool: str, nn.Module
        Pooling mechanism for `DownNFuseStage`s. Supports 'max',
        'min', 'mean', 'sum' for string arguments.
        See `pool_factory()` for more
    :param unpool: str
        Unpooling mechanism for `UpNFuseStage`s. Only supports
        'index' for now
    :param fusion: str
        Fusion mechanism used in `DownNFuseStage` and `UpNFuseStage`
        to merge node features from different branches. Supports
        'cat', 'residual', 'first', 'second'. See `fusion_factory()`
        for more
    :param norm_mode: str
        Indexing mode used for feature normalization. This will be
        passed to `Data.norm_index()`. 'graph' will normalize
        features per graph (i.e. per cloud, i.e. per batch item).
        'node' will normalize per node (i.e. per point). 'segment'
        will normalize per segment (i.e. per cluster)
    :param output_stage_wise: bool
        If True, the output will contain the features for each node
        at each partition level 1 and above. If False, only the
        features for partition level 1 will be returned. Note we do
        not compute the features for level 0, since the entire goal
        of this superpoint-based reasoning is to mitigate compute
        and memory by circumventing the need to manipulate such
        full-resolution objects
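
    Example: a minimal configuration sketch. The channel sizes below
    are hypothetical, untuned values, and `nag` is assumed to be a
    3-level hierarchical partition whose handcrafted feature
    dimensions match the first channel of each input MLP::

        model = SPT(
            point_mlp=[11, 32, 64],
            down_dim=[64, 64],
            down_in_mlp=[[64, 64], [64, 64]],
            down_out_mlp=[[64, 64], [64, 64]],
            down_num_heads=4,
            down_num_blocks=2,
            up_dim=[64],
            up_in_mlp=[[128, 64]],
            up_out_mlp=[[64, 64]],
            up_num_heads=4,
            up_num_blocks=1)
        out = model(nag)  # level-1 node features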
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
|
|
|
point_mlp=None, |
|
|
point_drop=None, |
|
|
|
|
|
nano=False, |
|
|
|
|
|
down_dim=None, |
|
|
down_pool_dim=None, |
|
|
down_in_mlp=None, |
|
|
down_out_mlp=None, |
|
|
down_mlp_drop=None, |
|
|
down_num_heads=1, |
|
|
down_num_blocks=0, |
|
|
down_ffn_ratio=4, |
|
|
down_residual_drop=None, |
|
|
down_attn_drop=None, |
|
|
down_drop_path=None, |
|
|
|
|
|
up_dim=None, |
|
|
up_in_mlp=None, |
|
|
up_out_mlp=None, |
|
|
up_mlp_drop=None, |
|
|
up_num_heads=1, |
|
|
up_num_blocks=0, |
|
|
up_ffn_ratio=4, |
|
|
up_residual_drop=None, |
|
|
up_attn_drop=None, |
|
|
up_drop_path=None, |
|
|
|
|
|
node_mlp=None, |
|
|
h_edge_mlp=None, |
|
|
v_edge_mlp=None, |
|
|
mlp_activation=nn.LeakyReLU(), |
|
|
mlp_norm=BatchNorm, |
|
|
qk_dim=8, |
|
|
qkv_bias=True, |
|
|
qk_scale=None, |
|
|
in_rpe_dim=18, |
|
|
activation=nn.LeakyReLU(), |
|
|
norm=LayerNorm, |
|
|
pre_norm=True, |
|
|
no_sa=False, |
|
|
no_ffn=False, |
|
|
k_rpe=False, |
|
|
q_rpe=False, |
|
|
v_rpe=False, |
|
|
k_delta_rpe=False, |
|
|
q_delta_rpe=False, |
|
|
qk_share_rpe=False, |
|
|
q_on_minus_rpe=False, |
|
|
share_hf_mlps=False, |
|
|
stages_share_rpe=False, |
|
|
blocks_share_rpe=False, |
|
|
heads_share_rpe=False, |
|
|
|
|
|
use_pos=True, |
|
|
use_node_hf=True, |
|
|
use_diameter=False, |
|
|
use_diameter_parent=False, |
|
|
pool='max', |
|
|
unpool='index', |
|
|
fusion='cat', |
|
|
norm_mode='graph', |
|
|
output_stage_wise=False): |
|
|
super().__init__() |
|
|
|
|
|
        self.nano = nano
        self.use_pos = use_pos
        self.use_node_hf = use_node_hf
        self.use_diameter = use_diameter
        self.use_diameter_parent = use_diameter_parent
        self.norm_mode = norm_mode
        self.stages_share_rpe = stages_share_rpe
        self.blocks_share_rpe = blocks_share_rpe
        self.heads_share_rpe = heads_share_rpe
        self.output_stage_wise = output_stage_wise

        # Convert the down stage arguments to per-stage lists, aligned
        # with `down_dim`
        (
            down_dim,
            down_pool_dim,
            down_in_mlp,
            down_out_mlp,
            down_mlp_drop,
            down_num_heads,
            down_num_blocks,
            down_ffn_ratio,
            down_residual_drop,
            down_attn_drop,
            down_drop_path
        ) = listify_with_reference(
            down_dim,
            down_pool_dim,
            down_in_mlp,
            down_out_mlp,
            down_mlp_drop,
            down_num_heads,
            down_num_blocks,
            down_ffn_ratio,
            down_residual_drop,
            down_attn_drop,
            down_drop_path)

        # Convert the up stage arguments to per-stage lists, aligned
        # with `up_dim`
        (
            up_dim,
            up_in_mlp,
            up_out_mlp,
            up_mlp_drop,
            up_num_heads,
            up_num_blocks,
            up_ffn_ratio,
            up_residual_drop,
            up_attn_drop,
            up_drop_path
        ) = listify_with_reference(
            up_dim,
            up_in_mlp,
            up_out_mlp,
            up_mlp_drop,
            up_num_heads,
            up_num_blocks,
            up_ffn_ratio,
            up_residual_drop,
            up_attn_drop,
            up_drop_path)

        # Local helpers: number of down and up stages, and whether
        # handcrafted horizontal and vertical edge features are needed
        num_down = len(down_dim) - self.nano
        num_up = len(up_dim)
        needs_h_edge_hf = any(x > 0 for x in down_num_blocks + up_num_blocks)
        needs_v_edge_hf = num_down > 0 and isinstance(
            pool_factory(pool, down_pool_dim[0]), BaseAttentivePool)

        # Build the MLPs encoding handcrafted node features. For the
        # nano model, an extra MLP is needed for the first `Stage`
        node_mlp = node_mlp if use_node_hf else None
        self.node_mlps = _build_mlps(
            node_mlp,
            num_down + self.nano,
            mlp_activation,
            mlp_norm,
            share_hf_mlps)

        # Build the MLPs encoding handcrafted horizontal edge features,
        # only needed if at least one stage uses self-attention blocks
        h_edge_mlp = h_edge_mlp if needs_h_edge_hf else None
        self.h_edge_mlps = _build_mlps(
            h_edge_mlp,
            num_down + self.nano,
            mlp_activation,
            mlp_norm,
            share_hf_mlps)

        # Build the MLPs encoding handcrafted vertical edge features,
        # only needed if the pooling mechanism is attentive
        v_edge_mlp = v_edge_mlp if needs_v_edge_hf else None
        self.v_edge_mlps = _build_mlps(
            v_edge_mlp,
            num_down,
            mlp_activation,
            mlp_norm,
            share_hf_mlps)

        # Build the first stage: a `Stage` operating on level-1 nodes
        # for the nano model, a `PointStage` operating on raw level-0
        # points otherwise
        if self.nano:
            self.first_stage = Stage(
                down_dim[0],
                num_blocks=down_num_blocks[0],
                in_mlp=down_in_mlp[0],
                out_mlp=down_out_mlp[0],
                mlp_activation=mlp_activation,
                mlp_norm=mlp_norm,
                mlp_drop=down_mlp_drop[0],
                num_heads=down_num_heads[0],
                qk_dim=qk_dim,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                in_rpe_dim=in_rpe_dim,
                ffn_ratio=down_ffn_ratio[0],
                residual_drop=down_residual_drop[0],
                attn_drop=down_attn_drop[0],
                drop_path=down_drop_path[0],
                activation=activation,
                norm=norm,
                pre_norm=pre_norm,
                no_sa=no_sa,
                no_ffn=no_ffn,
                k_rpe=k_rpe,
                q_rpe=q_rpe,
                v_rpe=v_rpe,
                k_delta_rpe=k_delta_rpe,
                q_delta_rpe=q_delta_rpe,
                qk_share_rpe=qk_share_rpe,
                q_on_minus_rpe=q_on_minus_rpe,
                use_pos=use_pos,
                use_diameter=use_diameter,
                use_diameter_parent=use_diameter_parent,
                blocks_share_rpe=blocks_share_rpe,
                heads_share_rpe=heads_share_rpe)
        else:
            self.first_stage = PointStage(
                point_mlp,
                mlp_activation=mlp_activation,
                mlp_norm=mlp_norm,
                mlp_drop=point_drop,
                use_pos=use_pos,
                use_diameter_parent=use_diameter_parent)

        # Operator merging the skip features with the handcrafted node
        # features in the up branch (concatenation)
        self.feature_fusion = CatFusion()

        if num_down > 0:
            # Build the key and query RPE encoders for the down stages,
            # shared across stages if `stages_share_rpe`
            down_k_rpe = _build_shared_rpe_encoders(
                k_rpe, num_down, 18, qk_dim, stages_share_rpe)

            # If the keys and queries share their RPE encoder, do not
            # build a separate encoder for the queries
            down_q_rpe = _build_shared_rpe_encoders(
                q_rpe and not (k_rpe and qk_share_rpe), num_down, 18,
                qk_dim, stages_share_rpe)

            # For the nano model, the first entry of the down arguments
            # is consumed by the first `Stage`, so prepend a None to
            # keep the RPE encoders aligned with the `DownNFuseStage`s
            if self.nano:
                down_k_rpe = [None] + down_k_rpe
                down_q_rpe = [None] + down_q_rpe

            self.down_stages = nn.ModuleList([
                DownNFuseStage(
                    dim,
                    num_blocks=num_blocks,
                    in_mlp=in_mlp,
                    out_mlp=out_mlp,
                    mlp_activation=mlp_activation,
                    mlp_norm=mlp_norm,
                    mlp_drop=mlp_drop,
                    num_heads=num_heads,
                    qk_dim=qk_dim,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    in_rpe_dim=in_rpe_dim,
                    ffn_ratio=ffn_ratio,
                    residual_drop=residual_drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path,
                    activation=activation,
                    norm=norm,
                    pre_norm=pre_norm,
                    no_sa=no_sa,
                    no_ffn=no_ffn,
                    k_rpe=stage_k_rpe,
                    q_rpe=stage_q_rpe,
                    v_rpe=v_rpe,
                    k_delta_rpe=k_delta_rpe,
                    q_delta_rpe=q_delta_rpe,
                    qk_share_rpe=qk_share_rpe,
                    q_on_minus_rpe=q_on_minus_rpe,
                    pool=pool_factory(pool, pool_dim),
                    fusion=fusion,
                    use_pos=use_pos,
                    use_diameter=use_diameter,
                    use_diameter_parent=use_diameter_parent,
                    blocks_share_rpe=blocks_share_rpe,
                    heads_share_rpe=heads_share_rpe)
                for i_down, (
                    dim,
                    num_blocks,
                    in_mlp,
                    out_mlp,
                    mlp_drop,
                    num_heads,
                    ffn_ratio,
                    residual_drop,
                    attn_drop,
                    drop_path,
                    stage_k_rpe,
                    stage_q_rpe,
                    pool_dim)
                in enumerate(zip(
                    down_dim,
                    down_num_blocks,
                    down_in_mlp,
                    down_out_mlp,
                    down_mlp_drop,
                    down_num_heads,
                    down_ffn_ratio,
                    down_residual_drop,
                    down_attn_drop,
                    down_drop_path,
                    down_k_rpe,
                    down_q_rpe,
                    down_pool_dim))
                if i_down >= self.nano])
        else:
            self.down_stages = None

        if num_up > 0:
            # Build the key and query RPE encoders for the up stages,
            # shared across stages if `stages_share_rpe`
            up_k_rpe = _build_shared_rpe_encoders(
                k_rpe, num_up, 18, qk_dim, stages_share_rpe)

            # If the keys and queries share their RPE encoder, do not
            # build a separate encoder for the queries
            up_q_rpe = _build_shared_rpe_encoders(
                q_rpe and not (k_rpe and qk_share_rpe), num_up, 18,
                qk_dim, stages_share_rpe)

            self.up_stages = nn.ModuleList([
                UpNFuseStage(
                    dim,
                    num_blocks=num_blocks,
                    in_mlp=in_mlp,
                    out_mlp=out_mlp,
                    mlp_activation=mlp_activation,
                    mlp_norm=mlp_norm,
                    mlp_drop=mlp_drop,
                    num_heads=num_heads,
                    qk_dim=qk_dim,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    in_rpe_dim=in_rpe_dim,
                    ffn_ratio=ffn_ratio,
                    residual_drop=residual_drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path,
                    activation=activation,
                    norm=norm,
                    pre_norm=pre_norm,
                    no_sa=no_sa,
                    no_ffn=no_ffn,
                    k_rpe=stage_k_rpe,
                    q_rpe=stage_q_rpe,
                    v_rpe=v_rpe,
                    k_delta_rpe=k_delta_rpe,
                    q_delta_rpe=q_delta_rpe,
                    qk_share_rpe=qk_share_rpe,
                    q_on_minus_rpe=q_on_minus_rpe,
                    unpool=unpool,
                    fusion=fusion,
                    use_pos=use_pos,
                    use_diameter=use_diameter,
                    use_diameter_parent=use_diameter_parent,
                    blocks_share_rpe=blocks_share_rpe,
                    heads_share_rpe=heads_share_rpe)
                for (
                    dim,
                    num_blocks,
                    in_mlp,
                    out_mlp,
                    mlp_drop,
                    num_heads,
                    ffn_ratio,
                    residual_drop,
                    attn_drop,
                    drop_path,
                    stage_k_rpe,
                    stage_q_rpe)
                in zip(
                    up_dim,
                    up_num_blocks,
                    up_in_mlp,
                    up_out_mlp,
                    up_mlp_drop,
                    up_num_heads,
                    up_ffn_ratio,
                    up_residual_drop,
                    up_attn_drop,
                    up_drop_path,
                    up_k_rpe,
                    up_q_rpe)])
        else:
            self.up_stages = None

        assert self.num_up_stages > 0 or not self.output_stage_wise, \
            "At least one up stage is needed for output_stage_wise=True"
        assert bool(self.down_stages) != bool(self.up_stages) \
               or self.num_down_stages >= self.num_up_stages, \
            "The number of Up stages should be <= the number of Down " \
            "stages."
        assert self.nano or self.num_down_stages > self.num_up_stages, \
            "The number of Up stages should be < the number of Down " \
            "stages. That is to say, we do not want to output Level-0 " \
            "features but at least Level-1."

    @property
    def num_down_stages(self):
        return len(self.down_stages) if self.down_stages is not None else 0

    @property
    def num_up_stages(self):
        return len(self.up_stages) if self.up_stages is not None else 0

    @property
    def out_dim(self):
        if self.output_stage_wise:
            out_dim = [stage.out_dim for stage in self.up_stages][::-1]
            out_dim += [self.down_stages[-1].out_dim]
            return out_dim
        if self.up_stages is not None:
            return self.up_stages[-1].out_dim
        if self.down_stages is not None:
            return self.down_stages[-1].out_dim
        return self.first_stage.out_dim
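
    # Note: with `output_stage_wise=True`, `out_dim` lists the feature
    # dimensions from level 1 up to the coarsest partition level,
    # matching the ordering of the stage-wise `forward()` output.
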
    def forward(self, nag):
        """Process a NAG and return the node features of the finest
        processed partition level, or the stage-wise features of all
        processed levels if `output_stage_wise=True`.
        """
        # The nano model does not process the level-0 points, so drop
        # the first level of the NAG
        if self.nano:
            nag = nag[1:]

        # For the nano model, encode the handcrafted node and
        # horizontal edge features before running the first stage
        if self.nano:
            if self.node_mlps is not None and self.node_mlps[0] is not None:
                norm_index = nag[0].norm_index(mode=self.norm_mode)
                nag[0].x = self.node_mlps[0](nag[0].x, batch=norm_index)
            if self.h_edge_mlps is not None:
                norm_index = nag[0].norm_index(mode=self.norm_mode)
                norm_index = norm_index[nag[0].edge_index[0]]
                nag[0].edge_attr = self.h_edge_mlps[0](
                    nag[0].edge_attr, batch=norm_index)

        # Run the first stage on the lowest partition level available
        x, diameter = self.first_stage(
            nag[0].x if self.use_node_hf else None,
            nag[0].norm_index(mode=self.norm_mode),
            pos=nag[0].pos,
            diameter=None,
            node_size=getattr(nag[0], 'node_size', None),
            super_index=nag[0].super_index,
            edge_index=nag[0].edge_index,
            edge_attr=nag[0].edge_attr)

        # Store the node diameters for use by the next level
        nag[1].diameter = diameter

        # Keep track of the stage outputs for the skip connections. For
        # the nano model, the first stage output is a level-1 output
        down_outputs = []
        if self.nano:
            down_outputs.append(x)

        if self.down_stages is not None:
            # Iterate over the down stages, along with the MLPs encoding
            # their handcrafted node and edge features. For the nano
            # model, the first node and horizontal edge MLPs were
            # already consumed before the first stage
            enum = enumerate(zip(
                self.down_stages,
                self.node_mlps[int(self.nano):],
                self.h_edge_mlps[int(self.nano):],
                self.v_edge_mlps))

            for i_stage, (stage, node_mlp, h_edge_mlp, v_edge_mlp) in enum:

                # Partition level processed by this down stage
                i_level = i_stage + 1

                # Encode the handcrafted node, horizontal edge, and
                # vertical edge features for this level, if any
                if node_mlp is not None:
                    norm_index = nag[i_level].norm_index(mode=self.norm_mode)
                    nag[i_level].x = node_mlp(nag[i_level].x, batch=norm_index)
                if h_edge_mlp is not None:
                    norm_index = nag[i_level].norm_index(mode=self.norm_mode)
                    norm_index = norm_index[nag[i_level].edge_index[0]]
                    edge_attr = getattr(nag[i_level], 'edge_attr', None)
                    if edge_attr is not None:
                        nag[i_level].edge_attr = h_edge_mlp(
                            edge_attr, batch=norm_index)
                if v_edge_mlp is not None:
                    norm_index = nag[i_level - 1].norm_index(mode=self.norm_mode)
                    v_edge_attr = getattr(nag[i_level - 1], 'v_edge_attr', None)
                    if v_edge_attr is not None:
                        nag[i_level - 1].v_edge_attr = v_edge_mlp(
                            v_edge_attr, batch=norm_index)

                # Run the down stage and keep track of its output for
                # the skip connections
                x, diameter = self._forward_down_stage(stage, nag, i_level, x)
                down_outputs.append(x)

                # There is no next level to receive the diameters after
                # the last one
                if i_level == nag.num_levels - 1:
                    continue

                # Store the node diameters for use by the next level
                nag[i_level + 1].diameter = diameter

        # Iteratively run the up stages, consuming the down outputs as
        # skip connections, from coarser to finer levels
        up_outputs = []
        if self.up_stages is not None:
            for i_stage, stage in enumerate(self.up_stages):
                i_level = self.num_down_stages - i_stage - 1
                x_skip = down_outputs[-(2 + i_stage)]
                x, _ = self._forward_up_stage(stage, nag, i_level, x, x_skip)
                up_outputs.append(x)

        # If required, return the stage-wise outputs, ordered from
        # level 1 up to the coarsest partition level
        if self.output_stage_wise:
            out = [x] + up_outputs[::-1][1:] + [down_outputs[-1]]
            return out

        return x

    def _forward_down_stage(self, stage, nag, i_level, x):
        is_last_level = (i_level == nag.num_levels - 1)
        x_handcrafted = nag[i_level].x if self.use_node_hf else None
        return stage(
            x_handcrafted,
            x,
            nag[i_level].norm_index(mode=self.norm_mode),
            nag[i_level - 1].super_index,
            pos=nag[i_level].pos,
            diameter=nag[i_level].diameter,
            node_size=nag[i_level].node_size,
            super_index=nag[i_level].super_index if not is_last_level else None,
            edge_index=nag[i_level].edge_index,
            edge_attr=nag[i_level].edge_attr,
            v_edge_attr=nag[i_level - 1].v_edge_attr,
            num_super=nag[i_level].num_nodes)

    def _forward_up_stage(self, stage, nag, i_level, x, x_skip):
        x_handcrafted = nag[i_level].x if self.use_node_hf else None
        return stage(
            self.feature_fusion(x_skip, x_handcrafted),
            x,
            nag[i_level].norm_index(mode=self.norm_mode),
            nag[i_level].super_index,
            pos=nag[i_level].pos,
            diameter=nag[i_level - self.nano].diameter,
            node_size=nag[i_level].node_size,
            super_index=nag[i_level].super_index,
            edge_index=nag[i_level].edge_index,
            edge_attr=nag[i_level].edge_attr)


def _build_shared_rpe_encoders(
        rpe, num_stages, in_dim, out_dim, stages_share):
    """Local helper to build the RPE encoders for SPT. The main goal is
    to make the construction of shared encoders easier.

    Note that setting stages_share=True will make all stages, blocks and
    heads use the same RPE encoder.
    """
    if not isinstance(rpe, bool):
        assert stages_share, \
            "If anything else but a boolean is passed for the RPE encoder, " \
            "this value will be passed to all Stages and `stages_share` " \
            "should be set to True."
        return [rpe] * num_stages

    # If the stages share their RPE encoder, build a single encoder
    # here and pass it to all stages
    if stages_share and rpe:
        return [nn.Linear(in_dim, out_dim)] * num_stages

    # Otherwise, simply pass the boolean to each stage, which will
    # build its own encoder if required
    return [rpe] * num_stages


def _build_mlps(layers, num_stage, activation, norm, shared):
    if layers is None:
        return [None] * num_stage

    if shared:
        return nn.ModuleList([
            MLP(layers, activation=activation, norm=norm)] * num_stage)

    return nn.ModuleList([
        MLP(layers, activation=activation, norm=norm)
        for _ in range(num_stage)])
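

# An illustrative sketch of the `shared` semantics (the arguments are
# hypothetical), mirroring `_build_shared_rpe_encoders()`: with
# shared=True, a single MLP instance is reused across all stages:
#   mlps = _build_mlps([32, 64], 3, nn.LeakyReLU(), BatchNorm, True)
#   mlps[0] is mlps[2]  # -> True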