| from torch import nn |
| from src.utils import listify_with_reference |
| from src.nn import Stage, PointStage, DownNFuseStage, UpNFuseStage, \ |
| BatchNorm, CatFusion, MLP, LayerNorm |
| from src.nn.pool import BaseAttentivePool |
| from src.nn.pool import pool_factory |
|
|
| __all__ = ['SPT'] |
|
|
|
|
| class SPT(nn.Module): |
| """Superpoint Transformer. A UNet-like architecture processing NAG. |
| |
| The architecture can be (roughly) summarized as: |
| |
| p_0, x_0 --------- PointStage |
| \ |
| p_1, x_1, e_1 -- DownNFuseStage_1 ------- UpNFuseStage_1 --> out_1 |
| \ | |
| p_2, x_2, e_2 -- DownNFuseStage_2 ------- UpNFuseStage_2 --> out_2 |
| \ | |
| ... ... |
| |
| Where: |
| - p_0: point positions |
| - x_0: input point features (handcrafted) |
| - p_i: node positions (i.e. superpoint centroid) at level i |
| - x_i: input node features (handcrafted superpoint features) at |
| level i |
| - e_i: input edge features (handcrafted horizontal superpoint graph |
| edge features) at level i |
| - out_i: node-wise output features at level i |
| |
| |
| :param point_mlp: List[int] |
| Channels for the input MLP of the `PointStage` |
| :param point_drop: float |
| Dropout rate for the last layer of the input and output MLPs |
| in `PointStage` |
| |
| :param nano: bool |
| If True, the `PointStage` will be removed and the model will |
| only operate on superpoints, without extracting features |
| from the points. This lightweight model saves compute and |
| memory, at the potential expense of high-resolution |
| reasoning |
| |
| :param down_dim: List[int], int |
| Feature dimension for each `DownNFuseStage` (i.e. not |
| including the `PointStage` when `nano=False`) |
    :param down_pool_dim: List[int], int
        Feature dimension for the pooling module of each
        `DownNFuseStage` (used, among others, by attentive pooling
        mechanisms). See `pool_factory()` for more
| :param down_in_mlp: List[List[int]], List[int] |
| Channels for the input MLP of each `DownNFuseStage` |
| :param down_out_mlp: List[List[int]], List[int] |
        Channels for the output MLP of each `DownNFuseStage`. The
        first channel for each stage must match the corresponding
        value in `down_dim`
| :param down_mlp_drop: List[float], float |
| Dropout rate for the last layer of the input and output MLPs |
| of each `DownNFuseStage` |
| :param down_num_heads: List[int], int |
        Number of self-attention heads for each `DownNFuseStage`
| :param down_num_blocks: List[int], int |
        Number of self-attention blocks for each `DownNFuseStage`
| :param down_ffn_ratio: List[float], float |
| Multiplicative factor for computing the dimension of the |
        `FFN` inverted bottleneck, for each `DownNFuseStage`. See
| `TransformerBlock` |
| :param down_residual_drop: List[float], float |
| Dropout on the output self-attention features for each |
| `DownNFuseStage`. See `TransformerBlock` |
| :param down_attn_drop: List[float], float |
| Dropout on the self-attention weights for each |
| `DownNFuseStage`. See `TransformerBlock` |
| :param down_drop_path: List[float], float |
| Dropout on the residual paths for each `DownNFuseStage`. See |
| `TransformerBlock` |
| |
| :param up_dim: List[int], int |
| Feature dimension for each `UpNFuseStage` |
| :param up_in_mlp: List[List[int]], List[int] |
| Channels for the input MLP of each `UpNFuseStage` |
| :param up_out_mlp: List[List[int]], List[int] |
        Channels for the output MLP of each `UpNFuseStage`. The
        first channel for each stage must match the corresponding
        value in `up_dim`
| :param up_mlp_drop: List[float], float |
| Dropout rate for the last layer of the input and output MLPs |
| of each `UpNFuseStage` |
| :param up_num_heads: List[int], int |
        Number of self-attention heads for each `UpNFuseStage`
| :param up_num_blocks: List[int], int |
        Number of self-attention blocks for each `UpNFuseStage`
| :param up_ffn_ratio: List[float], float |
| Multiplicative factor for computing the dimension of the |
        `FFN` inverted bottleneck, for each `UpNFuseStage`. See
| `TransformerBlock` |
| :param up_residual_drop: List[float], float |
| Dropout on the output self-attention features for each |
| `UpNFuseStage`. See `TransformerBlock` |
| :param up_attn_drop: List[float], float |
| Dropout on the self-attention weights for each |
| `UpNFuseStage`. See `TransformerBlock` |
| :param up_drop_path: List[float], float |
| Dropout on the residual paths for each `UpNFuseStage`. See |
| `TransformerBlock` |
| |
| :param node_mlp: List[int] |
| Channels for the MLPs that will encode handcrafted node |
| (i.e. segment, superpoint) features. These will be called |
| before each `DownNFuseStage` and their output will be |
| concatenated to any already-existing features and passed |
        to `DownNFuseStage` and `UpNFuseStage`. In the special case
        of a `nano=True` model, the first MLP will also be run before
        the first `Stage`
| :param h_edge_mlp: List[int] |
| Channels for the MLPs that will encode handcrafted |
| horizontal edge (i.e. edges in the superpoint adjacency |
| graph at each partition level) features. These will be |
| called before each `DownNFuseStage` and their output will be |
| passed as `edge_attr` to `DownNFuseStage` and `UpNFuseStage` |
| :param v_edge_mlp: List[int] |
| Channels for the MLPs that will encode handcrafted |
| vertical edge (i.e. edges connecting nodes to their parent |
| in the above partition level) features. These will be |
| called before each `DownNFuseStage` and their output will be |
| passed as `v_edge_attr` to `DownNFuseStage` and |
| `UpNFuseStage` |
| :param mlp_activation: nn.Module |
| Activation function used for all MLPs throughout the |
| architecture |
    :param mlp_norm: nn.Module
| Normalization function for all MLPs throughout the |
| architecture |
| :param qk_dim: int |
| Dimension of the queries and keys. See `SelfAttentionBlock` |
| :param qkv_bias: bool |
| Whether the linear layers producing queries, keys, and |
| values should have a bias. See `SelfAttentionBlock` |
| :param qk_scale: str |
| Scaling applied to the query*key product before the softmax. |
| More specifically, one may want to normalize the query-key |
| compatibilities based on the number of dimensions (referred |
| to as 'd' here) as in a vanilla Transformer implementation, |
| or based on the number of neighbors each node has in the |
| attention graph (referred to as 'g' here). If nothing is |
| specified the scaling will be `1 / (sqrt(d) * sqrt(g))`, |
| which is equivalent to passing `'d.g'`. Passing `'d+g'` will |
| yield `1 / (sqrt(d) + sqrt(g))`. Meanwhile, passing 'd' will |
| yield `1 / sqrt(d)`, and passing `'g'` will yield |
| `1 / sqrt(g)`. See `SelfAttentionBlock` |
    :param in_rpe_dim: int
        Dimension of the features used as input for the relative
        positional encodings. See `SelfAttentionBlock`
| :param activation: nn.Module |
| Activation function used in the `FFN` modules. See |
| `TransformerBlock` |
| :param norm: nn.Module |
| Normalization function for the `FFN` module. See |
| `TransformerBlock` |
| :param pre_norm: bool |
| Whether the normalization should be applied before or after |
| the `SelfAttentionBlock` and `FFN` in the residual branches. |
        See `TransformerBlock`
| :param no_sa: bool |
| Whether a self-attention residual branch should be used at |
        all. See `TransformerBlock`
| :param no_ffn: bool |
| Whether a feed-forward residual branch should be used at |
        all. See `TransformerBlock`
| :param k_rpe: bool |
| Whether keys should receive relative positional encodings |
| computed from edge features. See `SelfAttentionBlock` |
| :param q_rpe: bool |
| Whether queries should receive relative positional encodings |
| computed from edge features. See `SelfAttentionBlock` |
| :param v_rpe: bool |
| Whether values should receive relative positional encodings |
| computed from edge features. See `SelfAttentionBlock` |
| :param k_delta_rpe: bool |
| Whether keys should receive relative positional encodings |
| computed from the difference between source and target node |
| features. See `SelfAttentionBlock` |
| :param q_delta_rpe: bool |
| Whether queries should receive relative positional encodings |
| computed from the difference between source and target node |
| features. See `SelfAttentionBlock` |
| :param qk_share_rpe: bool |
| Whether queries and keys should use the same parameters for |
| building relative positional encodings. See |
| `SelfAttentionBlock` |
| :param q_on_minus_rpe: bool |
| Whether relative positional encodings for queries should be |
        computed on the opposite of the features used for keys. This
        allows, for instance, breaking the symmetry when
        `qk_share_rpe=True` while still letting the relative
        positional encodings capture different meanings for keys and
        queries. See `SelfAttentionBlock`
| :param share_hf_mlps: bool |
| Whether stages should share the MLPs for encoding |
| handcrafted node, horizontal edge, and vertical edge |
| features |
| :param stages_share_rpe: bool |
| Whether all `Stage`s should share the same parameters for |
| building relative positional encodings |
| :param blocks_share_rpe: bool |
        Whether all the `TransformerBlock`s in the same `Stage`
| should share the same parameters for building relative |
| positional encodings |
| :param heads_share_rpe: bool |
| Whether attention heads should share the same parameters for |
| building relative positional encodings |
| |
| :param use_pos: bool |
| Whether the node's position (normalized with `UnitSphereNorm`) |
| should be concatenated to the features. See `Stage` |
| :param use_node_hf: bool |
| Whether handcrafted node (i.e. segment, superpoint) features |
| should be used at all. If False, `node_mlp` will be ignored |
| :param use_diameter: bool |
| Whether the node's diameter (currently estimated with |
| `UnitSphereNorm`) should be concatenated to the node features. |
| See `Stage` |
| :param use_diameter_parent: bool |
| Whether the node's parent diameter (currently estimated with |
| `UnitSphereNorm`) should be concatenated to the node features. |
| See `Stage` |
| :param pool: str, nn.Module |
| Pooling mechanism for `DownNFuseStage`s. Supports 'max', |
| 'min', 'mean', 'sum' for string arguments. |
| See `pool_factory()` for more |
| :param unpool: str |
| Unpooling mechanism for `UpNFuseStage`s. Only supports |
| 'index' for now |
| :param fusion: str |
| Fusion mechanism used in `DownNFuseStage` and `UpNFuseStage` |
| to merge node features from different branches. Supports |
| 'cat', 'residual', 'first', 'second'. See `fusion_factory()` |
| for more |
| :param norm_mode: str |
| Indexing mode used for feature normalization. This will be |
| passed to `Data.norm_index()`. 'graph' will normalize |
| features per graph (i.e. per cloud, i.e. per batch item). |
| 'node' will normalize per node (i.e. per point). 'segment' |
| will normalize per segment (i.e. per cluster) |
| :param output_stage_wise: bool |
        If True, the output contains the features for each node of
        each partition level 1 and above. If False, only the features
        for partition level 1 are returned. Note we do not compute
        features for level 0, since the entire goal of this
        superpoint-based reasoning is to save compute and memory by
        circumventing the need to manipulate such full-resolution
        objects
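
    Example (illustrative sketch only; all channel sizes below are
    placeholders and must be adapted to the handcrafted feature
    dimensions of the dataset at hand and to the channel constraints
    described above):

        model = SPT(
            point_mlp=[11, 64, 128],
            down_dim=[64, 64],
            down_in_mlp=[[128, 64], [64, 64]],
            down_out_mlp=[[64, 64], [64, 64]],
            down_num_heads=4,
            down_num_blocks=3,
            up_dim=[64],
            up_in_mlp=[[128, 64]],
            up_out_mlp=[[64, 64]],
            up_num_heads=4,
            up_num_blocks=1)

        # `nag` holds the hierarchical partition. The call returns the
        # level-1 node features, or a list of per-level features (from
        # level 1 to the coarsest) if `output_stage_wise=True`
        output = model(nag)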
| """ |
|
|
| def __init__( |
| self, |
| |
| point_mlp=None, |
| point_drop=None, |
| |
| nano=False, |
| |
| down_dim=None, |
| down_pool_dim=None, |
| down_in_mlp=None, |
| down_out_mlp=None, |
| down_mlp_drop=None, |
| down_num_heads=1, |
| down_num_blocks=0, |
| down_ffn_ratio=4, |
| down_residual_drop=None, |
| down_attn_drop=None, |
| down_drop_path=None, |
| |
| up_dim=None, |
| up_in_mlp=None, |
| up_out_mlp=None, |
| up_mlp_drop=None, |
| up_num_heads=1, |
| up_num_blocks=0, |
| up_ffn_ratio=4, |
| up_residual_drop=None, |
| up_attn_drop=None, |
| up_drop_path=None, |
| |
| node_mlp=None, |
| h_edge_mlp=None, |
| v_edge_mlp=None, |
| mlp_activation=nn.LeakyReLU(), |
| mlp_norm=BatchNorm, |
| qk_dim=8, |
| qkv_bias=True, |
| qk_scale=None, |
| in_rpe_dim=18, |
| activation=nn.LeakyReLU(), |
| norm=LayerNorm, |
| pre_norm=True, |
| no_sa=False, |
| no_ffn=False, |
| k_rpe=False, |
| q_rpe=False, |
| v_rpe=False, |
| k_delta_rpe=False, |
| q_delta_rpe=False, |
| qk_share_rpe=False, |
| q_on_minus_rpe=False, |
| share_hf_mlps=False, |
| stages_share_rpe=False, |
| blocks_share_rpe=False, |
| heads_share_rpe=False, |
| |
| use_pos=True, |
| use_node_hf=True, |
| use_diameter=False, |
| use_diameter_parent=False, |
| pool='max', |
| unpool='index', |
| fusion='cat', |
| norm_mode='graph', |
| output_stage_wise=False): |
| super().__init__() |
|
|
| self.nano = nano |
| self.use_pos = use_pos |
| self.use_node_hf = use_node_hf |
| self.use_diameter = use_diameter |
| self.use_diameter_parent = use_diameter_parent |
| self.norm_mode = norm_mode |
| self.stages_share_rpe = stages_share_rpe |
| self.blocks_share_rpe = blocks_share_rpe |
| self.heads_share_rpe = heads_share_rpe |
| self.output_stage_wise = output_stage_wise |
|
|
| |
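        # Normalize all per-stage arguments of the down path into
        # lists of matching length, so that scalar or shared values
        # apply to every `DownNFuseStage` (see `listify_with_reference`)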
| ( |
| down_dim, |
| down_pool_dim, |
| down_in_mlp, |
| down_out_mlp, |
| down_mlp_drop, |
| down_num_heads, |
| down_num_blocks, |
| down_ffn_ratio, |
| down_residual_drop, |
| down_attn_drop, |
| down_drop_path |
| ) = listify_with_reference( |
| down_dim, |
| down_pool_dim, |
| down_in_mlp, |
| down_out_mlp, |
| down_mlp_drop, |
| down_num_heads, |
| down_num_blocks, |
| down_ffn_ratio, |
| down_residual_drop, |
| down_attn_drop, |
| down_drop_path) |
|
|
| ( |
| up_dim, |
| up_in_mlp, |
| up_out_mlp, |
| up_mlp_drop, |
| up_num_heads, |
| up_num_blocks, |
| up_ffn_ratio, |
| up_residual_drop, |
| up_attn_drop, |
| up_drop_path |
| ) = listify_with_reference( |
| up_dim, |
| up_in_mlp, |
| up_out_mlp, |
| up_mlp_drop, |
| up_num_heads, |
| up_num_blocks, |
| up_ffn_ratio, |
| up_residual_drop, |
| up_attn_drop, |
| up_drop_path) |
|
|
| |
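        # Helper variables: number of Down and Up stages operating on
        # superpoint levels, and whether handcrafted horizontal and
        # vertical edge features will be needed at all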
| num_down = len(down_dim) - self.nano |
| num_up = len(up_dim) |
| needs_h_edge_hf = any(x > 0 for x in down_num_blocks + up_num_blocks) |
| needs_v_edge_hf = num_down > 0 and isinstance( |
| pool_factory(pool, down_pool_dim[0]), BaseAttentivePool) |
|
|
| |
| |
| |
| |
| |
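        # Build the MLPs pre-processing handcrafted node, horizontal
        # edge, and vertical edge features before each stage. Unneeded
        # MLPs are skipped by passing `None` to `_build_mlps`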
| node_mlp = node_mlp if use_node_hf else None |
| self.node_mlps = _build_mlps( |
| node_mlp, |
| num_down + self.nano, |
| mlp_activation, |
| mlp_norm, |
| share_hf_mlps) |
|
|
| h_edge_mlp = h_edge_mlp if needs_h_edge_hf else None |
| self.h_edge_mlps = _build_mlps( |
| h_edge_mlp, |
| num_down + self.nano, |
| mlp_activation, |
| mlp_norm, |
| share_hf_mlps) |
|
|
| v_edge_mlp = v_edge_mlp if needs_v_edge_hf else None |
| self.v_edge_mlps = _build_mlps( |
| v_edge_mlp, |
| num_down, |
| mlp_activation, |
| mlp_norm, |
| share_hf_mlps) |
|
|
| |
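        # Build the first stage: a superpoint `Stage` when `nano=True`,
        # otherwise a `PointStage` encoding the level-0 points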
| if self.nano: |
| self.first_stage = Stage( |
| down_dim[0], |
| num_blocks=down_num_blocks[0], |
| in_mlp=down_in_mlp[0], |
| out_mlp=down_out_mlp[0], |
| mlp_activation=mlp_activation, |
| mlp_norm=mlp_norm, |
| mlp_drop=down_mlp_drop[0], |
| num_heads=down_num_heads[0], |
| qk_dim=qk_dim, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| in_rpe_dim=in_rpe_dim, |
| ffn_ratio=down_ffn_ratio[0], |
| residual_drop=down_residual_drop[0], |
| attn_drop=down_attn_drop[0], |
| drop_path=down_drop_path[0], |
| activation=activation, |
| norm=norm, |
| pre_norm=pre_norm, |
| no_sa=no_sa, |
| no_ffn=no_ffn, |
| k_rpe=k_rpe, |
| q_rpe=q_rpe, |
| v_rpe=v_rpe, |
| k_delta_rpe=k_delta_rpe, |
| q_delta_rpe=q_delta_rpe, |
| qk_share_rpe=qk_share_rpe, |
| q_on_minus_rpe=q_on_minus_rpe, |
| use_pos=use_pos, |
| use_diameter=use_diameter, |
| use_diameter_parent=use_diameter_parent, |
| blocks_share_rpe=blocks_share_rpe, |
| heads_share_rpe=heads_share_rpe) |
| else: |
| self.first_stage = PointStage( |
| point_mlp, |
| mlp_activation=mlp_activation, |
| mlp_norm=mlp_norm, |
| mlp_drop=point_drop, |
| use_pos=use_pos, |
| use_diameter_parent=use_diameter_parent) |
|
|
| |
| |
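        # Fusion module merging skip features with handcrafted node
        # features before each `UpNFuseStage`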
| self.feature_fusion = CatFusion() |
|
|
| |
| if num_down > 0: |
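            # Build the DownNFuseStages (encoder), along with the
            # (possibly shared) key and query RPE encoders they rely on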
|
|
| |
| down_k_rpe = _build_shared_rpe_encoders( |
| k_rpe, num_down, 18, qk_dim, stages_share_rpe) |
|
|
| |
| |
| down_q_rpe = _build_shared_rpe_encoders( |
| q_rpe and not (k_rpe and qk_share_rpe), num_down, 18, qk_dim, |
| stages_share_rpe) |
|
|
| |
| |
| |
| |
| if self.nano: |
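                # In nano mode, the first entries of the `down_*`
                # arguments were consumed by the first `Stage`, so a
                # `None` encoder is prepended to keep the RPE lists
                # aligned with `down_dim`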
| down_k_rpe = [None] + down_k_rpe |
| down_q_rpe = [None] + down_q_rpe |
|
|
| self.down_stages = nn.ModuleList([ |
| DownNFuseStage( |
| dim, |
| num_blocks=num_blocks, |
| in_mlp=in_mlp, |
| out_mlp=out_mlp, |
| mlp_activation=mlp_activation, |
| mlp_norm=mlp_norm, |
| mlp_drop=mlp_drop, |
| num_heads=num_heads, |
| qk_dim=qk_dim, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| in_rpe_dim=in_rpe_dim, |
| ffn_ratio=ffn_ratio, |
| residual_drop=residual_drop, |
| attn_drop=attn_drop, |
| drop_path=drop_path, |
| activation=activation, |
| norm=norm, |
| pre_norm=pre_norm, |
| no_sa=no_sa, |
| no_ffn=no_ffn, |
| k_rpe=stage_k_rpe, |
| q_rpe=stage_q_rpe, |
| v_rpe=v_rpe, |
| k_delta_rpe=k_delta_rpe, |
| q_delta_rpe=q_delta_rpe, |
| qk_share_rpe=qk_share_rpe, |
| q_on_minus_rpe=q_on_minus_rpe, |
| pool=pool_factory(pool, pool_dim), |
| fusion=fusion, |
| use_pos=use_pos, |
| use_diameter=use_diameter, |
| use_diameter_parent=use_diameter_parent, |
| blocks_share_rpe=blocks_share_rpe, |
| heads_share_rpe=heads_share_rpe) |
| for |
| i_down, |
| (dim, |
| num_blocks, |
| in_mlp, |
| out_mlp, |
| mlp_drop, |
| num_heads, |
| ffn_ratio, |
| residual_drop, |
| attn_drop, |
| drop_path, |
| stage_k_rpe, |
| stage_q_rpe, |
| pool_dim) |
| in enumerate(zip( |
| down_dim, |
| down_num_blocks, |
| down_in_mlp, |
| down_out_mlp, |
| down_mlp_drop, |
| down_num_heads, |
| down_ffn_ratio, |
| down_residual_drop, |
| down_attn_drop, |
| down_drop_path, |
| down_k_rpe, |
| down_q_rpe, |
| down_pool_dim)) |
| if i_down >= self.nano]) |
| else: |
| self.down_stages = None |
|
|
| |
| if num_up > 0: |
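            # Build the UpNFuseStages (decoder), along with the
            # (possibly shared) key and query RPE encoders they rely on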
|
|
| |
| up_k_rpe = _build_shared_rpe_encoders( |
| k_rpe, num_up, 18, qk_dim, stages_share_rpe) |
|
|
| |
| |
| up_q_rpe = _build_shared_rpe_encoders( |
| q_rpe and not (k_rpe and qk_share_rpe), num_up, 18, qk_dim, |
| stages_share_rpe) |
|
|
| self.up_stages = nn.ModuleList([ |
| UpNFuseStage( |
| dim, |
| num_blocks=num_blocks, |
| in_mlp=in_mlp, |
| out_mlp=out_mlp, |
| mlp_activation=mlp_activation, |
| mlp_norm=mlp_norm, |
| mlp_drop=mlp_drop, |
| num_heads=num_heads, |
| qk_dim=qk_dim, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| in_rpe_dim=in_rpe_dim, |
| ffn_ratio=ffn_ratio, |
| residual_drop=residual_drop, |
| attn_drop=attn_drop, |
| drop_path=drop_path, |
| activation=activation, |
| norm=norm, |
| pre_norm=pre_norm, |
| no_sa=no_sa, |
| no_ffn=no_ffn, |
| k_rpe=stage_k_rpe, |
| q_rpe=stage_q_rpe, |
| v_rpe=v_rpe, |
| k_delta_rpe=k_delta_rpe, |
| q_delta_rpe=q_delta_rpe, |
| qk_share_rpe=qk_share_rpe, |
| q_on_minus_rpe=q_on_minus_rpe, |
| unpool=unpool, |
| fusion=fusion, |
| use_pos=use_pos, |
| use_diameter=use_diameter, |
| use_diameter_parent=use_diameter_parent, |
| blocks_share_rpe=blocks_share_rpe, |
| heads_share_rpe=heads_share_rpe) |
| for dim, |
| num_blocks, |
| in_mlp, |
| out_mlp, |
| mlp_drop, |
| num_heads, |
| ffn_ratio, |
| residual_drop, |
| attn_drop, |
| drop_path, |
| stage_k_rpe, |
| stage_q_rpe |
| in zip( |
| up_dim, |
| up_num_blocks, |
| up_in_mlp, |
| up_out_mlp, |
| up_mlp_drop, |
| up_num_heads, |
| up_ffn_ratio, |
| up_residual_drop, |
| up_attn_drop, |
| up_drop_path, |
| up_k_rpe, |
| up_q_rpe)]) |
| else: |
| self.up_stages = None |
|
|
| assert self.num_up_stages > 0 or not self.output_stage_wise, \ |
| "At least one up stage is needed for output_stage_wise=True" |
|
|
| assert bool(self.down_stages) != bool(self.up_stages) \ |
| or self.num_down_stages >= self.num_up_stages, \ |
| "The number of Up stages should be <= the number of Down " \ |
| "stages." |
| assert self.nano or self.num_down_stages > self.num_up_stages, \ |
| "The number of Up stages should be < the number of Down " \ |
| "stages. That is to say, we do not want to output Level-0 " \ |
| "features but at least Level-1." |
|
|
| @property |
| def num_down_stages(self): |
| return len(self.down_stages) if self.down_stages is not None else 0 |
|
|
| @property |
| def num_up_stages(self): |
| return len(self.up_stages) if self.up_stages is not None else 0 |
|
|
| @property |
| def out_dim(self): |
| if self.output_stage_wise: |
| out_dim = [stage.out_dim for stage in self.up_stages][::-1] |
| out_dim += [self.down_stages[-1].out_dim] |
| return out_dim |
| if self.up_stages is not None: |
| return self.up_stages[-1].out_dim |
| if self.down_stages is not None: |
| return self.down_stages[-1].out_dim |
| return self.first_stage.out_dim |
|
|
| def forward(self, nag): |
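        """Encode the input NAG and return node-wise features.

        :param nag: NAG
            Hierarchical partition the model operates on. Level 0
            holds the points (dropped when `nano=True`), higher levels
            hold the superpoints
        :return: features produced by the last stage (typically the
            level-1 nodes) or, if `output_stage_wise=True`, a list of
            per-level features ordered from level 1 to the coarsest
            level
        """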
| |
| |
| |
|
|
| |
| if self.nano: |
| nag = nag[1:] |
|
|
| |
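        # In nano mode, the handcrafted node and horizontal edge
        # features of the first level must be encoded before running
        # the first `Stage`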
| if self.nano: |
| if self.node_mlps is not None and self.node_mlps[0] is not None: |
| norm_index = nag[0].norm_index(mode=self.norm_mode) |
| nag[0].x = self.node_mlps[0](nag[0].x, batch=norm_index) |
| if self.h_edge_mlps is not None: |
| norm_index = nag[0].norm_index(mode=self.norm_mode) |
| norm_index = norm_index[nag[0].edge_index[0]] |
| nag[0].edge_attr = self.h_edge_mlps[0]( |
| nag[0].edge_attr, batch=norm_index) |
|
|
| |
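        # Encode the first level with the first stage (level-0 points,
        # or level-1 superpoints when `nano=True`)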
| x, diameter = self.first_stage( |
| nag[0].x if self.use_node_hf else None, |
| nag[0].norm_index(mode=self.norm_mode), |
| pos=nag[0].pos, |
| diameter=None, |
| node_size=getattr(nag[0], 'node_size', None), |
| super_index=nag[0].super_index, |
| edge_index=nag[0].edge_index, |
| edge_attr=nag[0].edge_attr) |
|
|
| |
| nag[1].diameter = diameter |
|
|
| |
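        # Iteratively encode the coarser partition levels with the
        # DownNFuseStages, storing each level's output for the skip
        # connections of the decoder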
| down_outputs = [] |
| if self.nano: |
| down_outputs.append(x) |
| if self.down_stages is not None: |
|
|
| enum = enumerate(zip( |
| self.down_stages, |
| self.node_mlps[int(self.nano):], |
| self.h_edge_mlps[int(self.nano):], |
| self.v_edge_mlps)) |
|
|
| for i_stage, (stage, node_mlp, h_edge_mlp, v_edge_mlp) in enum: |
|
|
| |
| |
| i_level = i_stage + 1 |
|
|
| |
| |
| |
| if node_mlp is not None: |
| norm_index = nag[i_level].norm_index(mode=self.norm_mode) |
| nag[i_level].x = node_mlp(nag[i_level].x, batch=norm_index) |
| if h_edge_mlp is not None: |
| norm_index = nag[i_level].norm_index(mode=self.norm_mode) |
| norm_index = norm_index[nag[i_level].edge_index[0]] |
| edge_attr = getattr(nag[i_level], 'edge_attr', None) |
| if edge_attr is not None: |
| nag[i_level].edge_attr = h_edge_mlp( |
| edge_attr, batch=norm_index) |
| if v_edge_mlp is not None: |
| norm_index = nag[i_level - 1].norm_index(mode=self.norm_mode) |
| v_edge_attr = getattr(nag[i_level], 'v_edge_attr', None) |
| if v_edge_attr is not None: |
| nag[i_level - 1].v_edge_attr = v_edge_mlp( |
| v_edge_attr, batch=norm_index) |
|
|
| |
| x, diameter = self._forward_down_stage(stage, nag, i_level, x) |
| down_outputs.append(x) |
|
|
| |
| if i_level == nag.num_levels - 1: |
| continue |
|
|
| |
| nag[i_level + 1].diameter = diameter |
|
|
| |
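        # Iteratively decode back towards level 1 with the
        # UpNFuseStages, fusing the encoder outputs as skip features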
| up_outputs = [] |
| if self.up_stages is not None: |
| for i_stage, stage in enumerate(self.up_stages): |
| i_level = self.num_down_stages - i_stage - 1 |
| x_skip = down_outputs[-(2 + i_stage)] |
| x, _ = self._forward_up_stage(stage, nag, i_level, x, x_skip) |
| up_outputs.append(x) |
|
|
| |
| |
| |
| |
| |
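        # If required, return the per-level features, ordered from
        # level 1 (finest) to the coarsest partition level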
| if self.output_stage_wise: |
| out = [x] + up_outputs[::-1][1:] + [down_outputs[-1]] |
| return out |
|
|
| return x |
|
|
| def _forward_down_stage(self, stage, nag, i_level, x): |
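        """Run a single `DownNFuseStage` on partition level `i_level`,
        pooling the level `i_level - 1` features `x` into the level
        `i_level` nodes.
        """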
| is_last_level = (i_level == nag.num_levels - 1) |
| x_handcrafted = nag[i_level].x if self.use_node_hf else None |
| return stage( |
| x_handcrafted, |
| x, |
| nag[i_level].norm_index(mode=self.norm_mode), |
| nag[i_level - 1].super_index, |
| pos=nag[i_level].pos, |
| diameter=nag[i_level].diameter, |
| node_size=nag[i_level].node_size, |
| super_index=nag[i_level].super_index if not is_last_level else None, |
| edge_index=nag[i_level].edge_index, |
| edge_attr=nag[i_level].edge_attr, |
| v_edge_attr=nag[i_level - 1].v_edge_attr, |
| num_super=nag[i_level].num_nodes) |
|
|
| def _forward_up_stage(self, stage, nag, i_level, x, x_skip): |
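        """Run a single `UpNFuseStage` on partition level `i_level`,
        unpooling the coarser-level features `x` onto the level
        `i_level` nodes and fusing them with the skip features
        `x_skip`.
        """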
| x_handcrafted = nag[i_level].x if self.use_node_hf else None |
| return stage( |
| self.feature_fusion(x_skip, x_handcrafted), |
| x, |
| nag[i_level].norm_index(mode=self.norm_mode), |
| nag[i_level].super_index, |
| pos=nag[i_level].pos, |
| diameter=nag[i_level - self.nano].diameter, |
| node_size=nag[i_level].node_size, |
| super_index=nag[i_level].super_index, |
| edge_index=nag[i_level].edge_index, |
| edge_attr=nag[i_level].edge_attr) |
|
|
|
|
| def _build_shared_rpe_encoders( |
| rpe, num_stages, in_dim, out_dim, stages_share): |
| """Local helper to build RPE encoders for spt. The main goal is to |
    make the construction of shared encoders easier.
| |
| Note that setting stages_share=True will make all stages, blocks and |
| heads use the same RPE encoder. |
| """ |
| if not isinstance(rpe, bool): |
        assert stages_share, \
            "If anything other than a boolean is passed for the RPE " \
            "encoder, this value will be passed to all Stages, and " \
            "`stages_share` should be set to True."
| return [rpe] * num_stages |
|
|
| |
| |
| |
| if stages_share and rpe: |
| return [nn.Linear(in_dim, out_dim)] * num_stages |
|
|
| return [rpe] * num_stages |
|
|
|
|
| def _build_mlps(layers, num_stage, activation, norm, shared): |
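    """Local helper building the MLPs that pre-process handcrafted
    features. Returns one MLP per stage (the very same module is
    reused across stages if `shared=True`), or a list of `None` if
    `layers` is None.
    """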
| if layers is None: |
| return [None] * num_stage |
|
|
| if shared: |
| return nn.ModuleList([ |
| MLP(layers, activation=activation, norm=norm)] * num_stage) |
|
|
| return nn.ModuleList([ |
| MLP(layers, activation=activation, norm=norm) |
| for _ in range(num_stage)]) |
|
|