import torch from torch import nn from torch_geometric.nn.aggr import SumAggregation from torch_geometric.nn.aggr import MeanAggregation from torch_geometric.nn.aggr import MaxAggregation from torch_geometric.nn.aggr import MinAggregation from torch_scatter import scatter_sum from torch_geometric.utils import softmax from src.utils.nn import init_weights, LearnableParameter, build_qk_scale_func __all__ = [ 'pool_factory', 'SumPool', 'MeanPool', 'MaxPool', 'MinPool', 'AttentivePool', 'AttentivePoolWithLearntQueries'] def pool_factory(pool, *args, **kwargs): """Build a Pool module from string or from an existing module. This helper is intended to be used as a helper in spt and Stage constructors. """ if isinstance(pool, (AggregationPoolMixIn, BaseAttentivePool)): return pool if pool == 'max': return MaxPool() if pool == 'min': return MinPool() if pool == 'mean': return MeanPool() if pool == 'sum': return SumPool() return pool(*args, **kwargs) class AggregationPoolMixIn: """MixIn class to convert torch-geometric Aggregation modules into Pool modules with our desired forward signature. :param x_child: Tensor of shape (Nc, Cc) Node features for the children nodes :param x_parent: Any Not used for Aggregation :param index: LongTensor of shape (Nc) Indices indicating the parent of each for each child node :param edge_attr: Any Not used for Aggregation :param num_pool: int Number of parent nodes Nc. If not provided, will be inferred from `index.max() + 1` :return: """ def __call__(self, x_child, x_parent, index, edge_attr=None, num_pool=None): return super().__call__(x_child, index=index, dim_size=num_pool) class SumPool(AggregationPoolMixIn, SumAggregation): pass class MeanPool(AggregationPoolMixIn, MeanAggregation): pass class MaxPool(AggregationPoolMixIn, MaxAggregation): pass class MinPool(AggregationPoolMixIn, MinAggregation): pass class BaseAttentivePool(nn.Module): """Base class for attentive pooling classes. This class is not intended to be instantiated, but avoids duplicating code between similar child classes, which are expected to implement: - `_get_query()` """ # TODO: this module could be used for pooling from one segment level # to the next. But requires defining how. With QKV paradigm ? Then # how to define Q for superpoints ? from max-pooled/mean-pooled # features ? from handcrafted features ? If not QKV, simply have a # FFN predict (multi-headed) attention scores to be softmaxed ? How # to guide pooling from the above level (same pb as for qkv) ? # TODO: see torch_geometric SoftmaxAggregation and # AttentionalAggregation for possibilities. Among which, a # learnable softmax temperature def __init__( self, dim=None, num_heads=1, in_dim=None, out_dim=None, qkv_bias=True, qk_dim=8, qk_scale=None, attn_drop=None, drop=None, in_rpe_dim=9, k_rpe=False, q_rpe=False, v_rpe=False, heads_share_rpe=False): super().__init__() assert dim % num_heads == 0, f"dim must be a multiple of num_heads" self.dim = dim self.num_heads = num_heads self.qk_dim = qk_dim self.qk_scale = build_qk_scale_func(dim, num_heads, qk_scale) self.heads_share_rpe = heads_share_rpe self.kv = nn.Linear(dim, qk_dim * num_heads + dim, bias=qkv_bias) # Build the RPE encoders, with the option of sharing weights # across all heads rpe_dim = qk_dim if heads_share_rpe else qk_dim * num_heads if not isinstance(k_rpe, bool): self.k_rpe = k_rpe else: self.k_rpe = nn.Linear(in_rpe_dim, rpe_dim) if k_rpe else None if not isinstance(q_rpe, bool): self.q_rpe = q_rpe else: self.q_rpe = nn.Linear(in_rpe_dim, rpe_dim) if q_rpe else None if v_rpe: raise NotImplementedError self.in_proj = nn.Linear(in_dim, dim) if in_dim is not None else None self.out_proj = nn.Linear(dim, out_dim) if out_dim is not None else None self.attn_drop = nn.Dropout(attn_drop) \ if attn_drop is not None and attn_drop > 0 else None self.out_drop = nn.Dropout(drop) \ if drop is not None and drop > 0 else None def forward( self, x_child, x_parent, index, edge_attr=None, num_pool=None): """ :param x_child: Tensor of shape (Nc, Cc) Node features for the children nodes :param x_parent: Tensor of shape (Np, Cp) Node features for the parent nodes :param index: LongTensor of shape (Nc) Indices indicating the parent of each for each child node :param edge_attr: FloatTensor or shape (Nc, F) Edge attributes for relative pose encoding :param num_pool: int Number of parent nodes Nc. If not provided, will be inferred from the shape of x_parent :return: """ Nc = x_child.shape[0] Np = x_parent.shape[0] if num_pool is None else num_pool H = self.num_heads D = self.qk_dim DH = D * H # Optional linear projection of features if self.in_proj is not None: x_child = self.in_proj(x_child) # Compute queries from parent features q = self._get_query(x_parent) # [Np, DH] # Compute keys and values from child features kv = self.kv(x_child) # [Nc, DH + C] # Expand queries and separate keys and values q = q[index].view(Nc, H, D) # [Nc, H, D] k = kv[:, :DH].view(Nc, H, D) # [Nc, H, D] v = kv[:, DH:].view(Nc, H, -1) # [Nc, H, C // H] # Apply scaling on the queries q = q * self.qk_scale(index) # TODO: add the relative positional encodings to the # compatibilities here # - k_rpe, q_rpe, v_rpe # - pos difference, absolute distance, squared distance, centroid distance, edge distance, ... # - with/out edge attributes # - mlp (L-LN-A-L), learnable lookup table (see Stratified Transformer) # - scalar rpe, vector rpe (see Stratified Transformer) if self.k_rpe is not None: rpe = self.k_rpe(edge_attr) # Expand RPE to all heads if heads share the RPE encoder if self.heads_share_rpe: rpe = rpe.repeat(1, H) k = k + rpe.view(Nc, H, -1) if self.q_rpe is not None: rpe = self.q_rpe(edge_attr) # Expand RPE to all heads if heads share the RPE encoder if self.heads_share_rpe: rpe = rpe.repeat(1, H) q = q + rpe.view(Nc, H, -1) # Compute compatibility scores from the query-key products compat = torch.einsum('nhd, nhd -> nh', q, k) # [Nc, H] # Compute the attention scores with scaled softmax attn = softmax(compat, index=index, dim=0, num_nodes=Np) # [Nc, H] # Optional attention dropout if self.attn_drop is not None: attn = self.attn_drop(attn) # Apply the attention on the values x = (v * attn.unsqueeze(-1)).view(Nc, self.dim) # [Nc, C] x = scatter_sum(x, index, dim=0, dim_size=Np) # [Np, C] # Optional linear projection of features if self.out_proj is not None: x = self.out_proj(x) # [Np, out_dim] # Optional dropout on projection of features if self.out_drop is not None: x = self.out_drop(x) # [Np, C] or [Np, out_dim] return x def _get_query(self, x_parent): """Overwrite this method to implement the attentive pooling. :param x_parent: Tensor of shape (Np, Cp) Node features for the parent nodes :return: Tensor of shape (Np, D * H) """ raise NotImplementedError def extra_repr(self) -> str: return f'dim={self.dim}, num_heads={self.num_heads}' class AttentivePool(BaseAttentivePool): def __init__( self, dim=None, q_in_dim=None, num_heads=1, in_dim=None, out_dim=None, qkv_bias=True, qk_dim=8, qk_scale=None, attn_drop=None, drop=None, in_rpe_dim=9, k_rpe=False, q_rpe=False, v_rpe=False, heads_share_rpe=False): super().__init__( dim=dim, num_heads=num_heads, in_dim=in_dim, out_dim=out_dim, qkv_bias=qkv_bias, qk_dim=qk_dim, qk_scale=qk_scale, attn_drop=attn_drop, drop=drop, in_rpe_dim=in_rpe_dim, k_rpe=k_rpe, q_rpe=q_rpe, v_rpe=v_rpe, heads_share_rpe=heads_share_rpe) # Queries will be built from input parent feature self.q = nn.Linear(q_in_dim, qk_dim * num_heads, bias=qkv_bias) # TODO: use FFN heare to deal with handcrafted features def _get_query(self, x_parent): """Build queries from input parent features :param x_parent: Tensor of shape (Np, Cp) Node features for the parent nodes :return: Tensor of shape (Np, D * H) """ return self.q(x_parent) # [Np, DH] class AttentivePoolWithLearntQueries(BaseAttentivePool): def __init__( self, dim=None, num_heads=1, in_dim=None, out_dim=None, qkv_bias=True, qk_dim=8, qk_scale=None, attn_drop=None, drop=None, in_rpe_dim=18, k_rpe=False, q_rpe=False, v_rpe=False, heads_share_rpe=False): super().__init__( dim=dim, num_heads=num_heads, in_dim=in_dim, out_dim=out_dim, qkv_bias=qkv_bias, qk_dim=qk_dim, qk_scale=qk_scale, attn_drop=attn_drop, drop=drop, in_rpe_dim=in_rpe_dim, k_rpe=k_rpe, q_rpe=q_rpe, v_rpe=v_rpe, heads_share_rpe=heads_share_rpe) # Each head will learn its own query and all parent nodes will # use these same queries. self.q = LearnableParameter(torch.zeros(qk_dim * num_heads)) # `init_weights` initializes the weights with a truncated normal # distribution init_weights(self.q) def _get_query(self, x_parent): """Build queries from learnable queries. The parent features are simply used to get the number of parent nodes and expand the learnt queries accordingly. :param x_parent: Tensor of shape (Np, Cp) Node features for the parent nodes :return: Tensor of shape (Np, D * H) """ Np = x_parent.shape[0] return self.q.repeat(Np, 1)