# NOTE(review): the four lines below are Hugging Face web-page residue
# (language tag, avatar caption, upload commit message) accidentally
# captured with the file — commented out so the module stays valid Python.
# English
# Shanci's picture
# Upload folder using huggingface_hub
# 26225c5 verified
import torch
from torch import nn
from torch_geometric.nn.aggr import SumAggregation
from torch_geometric.nn.aggr import MeanAggregation
from torch_geometric.nn.aggr import MaxAggregation
from torch_geometric.nn.aggr import MinAggregation
from torch_scatter import scatter_sum
from torch_geometric.utils import softmax
from src.utils.nn import init_weights, LearnableParameter, build_qk_scale_func
__all__ = [
'pool_factory', 'SumPool', 'MeanPool', 'MaxPool', 'MinPool',
'AttentivePool', 'AttentivePoolWithLearntQueries']
def pool_factory(pool, *args, **kwargs):
    """Build a Pool module from string or from an existing module. This
    helper is intended to be used as a helper in spt and Stage
    constructors.

    Already-constructed pool modules are returned untouched. The strings
    'max', 'min', 'mean' and 'sum' map to the corresponding aggregation
    pool. Anything else is assumed to be a callable (e.g. a pool class)
    and is invoked with the remaining arguments.
    """
    # An existing pool module is passed through as-is
    if isinstance(pool, (AggregationPoolMixIn, BaseAttentivePool)):
        return pool

    # String identifiers dispatch to the matching aggregation pool
    by_name = {
        'max': MaxPool,
        'min': MinPool,
        'mean': MeanPool,
        'sum': SumPool}
    if pool in by_name:
        return by_name[pool]()

    # Fallback: treat `pool` as a callable/class to instantiate
    return pool(*args, **kwargs)
class AggregationPoolMixIn:
    """MixIn class to convert torch-geometric Aggregation modules into
    Pool modules with our desired forward signature.

    :param x_child: Tensor of shape (Nc, Cc)
        Node features for the children nodes
    :param x_parent: Any
        Not used for Aggregation
    :param index: LongTensor of shape (Nc)
        Indices indicating the parent of each child node
    :param edge_attr: Any
        Not used for Aggregation
    :param num_pool: int
        Number of parent nodes Np. If not provided, will be inferred
        from `index.max() + 1`
    :return:
    """
    def __call__(self, x_child, x_parent, index, edge_attr=None, num_pool=None):
        # Delegate to the sibling Aggregation class in the MRO, which
        # only expects the child features, index and output size
        return super().__call__(x_child, index=index, dim_size=num_pool)
class SumPool(AggregationPoolMixIn, SumAggregation):
    """Pool by summing child features within each parent."""
    pass
class MeanPool(AggregationPoolMixIn, MeanAggregation):
    """Pool by averaging child features within each parent."""
    pass
class MaxPool(AggregationPoolMixIn, MaxAggregation):
    """Pool by taking the channelwise maximum of child features within
    each parent."""
    pass
class MinPool(AggregationPoolMixIn, MinAggregation):
    """Pool by taking the channelwise minimum of child features within
    each parent."""
    pass
class BaseAttentivePool(nn.Module):
    """Base class for attentive pooling classes. This class is not
    intended to be instantiated, but avoids duplicating code between
    similar child classes, which are expected to implement:
      - `_get_query()`: build the (Np, qk_dim * num_heads) queries from
        the parent node features
    """

    # TODO: this module could be used for pooling from one segment level
    #  to the next. But requires defining how. With QKV paradigm ? Then
    #  how to define Q for superpoints ? from max-pooled/mean-pooled
    #  features ? from handcrafted features ? If not QKV, simply have a
    #  FFN predict (multi-headed) attention scores to be softmaxed ? How
    #  to guide pooling from the above level (same pb as for qkv) ?
    # TODO: see torch_geometric SoftmaxAggregation and
    #  AttentionalAggregation for possibilities. Among which, a
    #  learnable softmax temperature

    def __init__(
            self,
            dim=None,
            num_heads=1,
            in_dim=None,
            out_dim=None,
            qkv_bias=True,
            qk_dim=8,
            qk_scale=None,
            attn_drop=None,
            drop=None,
            in_rpe_dim=9,
            k_rpe=False,
            q_rpe=False,
            v_rpe=False,
            heads_share_rpe=False):
        """
        :param dim: int
            Dimension of the child node features. Must be a multiple of
            `num_heads`
        :param num_heads: int
            Number of attention heads
        :param in_dim: int
            If provided, child features are first linearly projected
            from `in_dim` to `dim`
        :param out_dim: int
            If provided, pooled features are linearly projected from
            `dim` to `out_dim`
        :param qkv_bias: bool
            Whether the key-value (and, in child classes, query) linear
            layers carry a bias term
        :param qk_dim: int
            Per-head dimension of the query and key vectors
        :param qk_scale:
            Scaling applied to the queries; passed to
            `build_qk_scale_func`
        :param attn_drop: float
            Dropout rate applied to the attention scores
        :param drop: float
            Dropout rate applied to the output features
        :param in_rpe_dim: int
            Input dimension of the built-in RPE linear encoders
        :param k_rpe: bool or nn.Module
            Key relative positional encoding. Pass a module to use it
            as-is, or True to build a Linear encoder
        :param q_rpe: bool or nn.Module
            Query relative positional encoding, same convention as
            `k_rpe`
        :param v_rpe: bool
            Not supported; passing True raises NotImplementedError
        :param heads_share_rpe: bool
            If True, a single `qk_dim`-sized RPE encoding is computed
            and repeated across all heads
        """
        super().__init__()

        # Raise rather than assert: assert statements are stripped when
        # Python runs with -O, which would silently skip this check. The
        # message also reports the offending values (the original used
        # an f-string with no placeholders)
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be a multiple of num_heads "
                f"({num_heads})")

        self.dim = dim
        self.num_heads = num_heads
        self.qk_dim = qk_dim
        self.qk_scale = build_qk_scale_func(dim, num_heads, qk_scale)
        self.heads_share_rpe = heads_share_rpe

        # Single linear layer producing both the keys (qk_dim * num_heads
        # channels) and the values (dim channels) from child features
        self.kv = nn.Linear(dim, qk_dim * num_heads + dim, bias=qkv_bias)

        # Build the RPE encoders, with the option of sharing weights
        # across all heads
        rpe_dim = qk_dim if heads_share_rpe else qk_dim * num_heads

        if not isinstance(k_rpe, bool):
            self.k_rpe = k_rpe
        else:
            self.k_rpe = nn.Linear(in_rpe_dim, rpe_dim) if k_rpe else None

        if not isinstance(q_rpe, bool):
            self.q_rpe = q_rpe
        else:
            self.q_rpe = nn.Linear(in_rpe_dim, rpe_dim) if q_rpe else None

        if v_rpe:
            raise NotImplementedError

        self.in_proj = nn.Linear(in_dim, dim) if in_dim is not None else None
        self.out_proj = nn.Linear(dim, out_dim) if out_dim is not None else None

        self.attn_drop = nn.Dropout(attn_drop) \
            if attn_drop is not None and attn_drop > 0 else None
        self.out_drop = nn.Dropout(drop) \
            if drop is not None and drop > 0 else None

    def forward(
            self, x_child, x_parent, index, edge_attr=None, num_pool=None):
        """Attentively pool child node features onto their parents.

        :param x_child: Tensor of shape (Nc, Cc)
            Node features for the children nodes
        :param x_parent: Tensor of shape (Np, Cp)
            Node features for the parent nodes
        :param index: LongTensor of shape (Nc)
            Indices indicating the parent of each child node
        :param edge_attr: FloatTensor of shape (Nc, F)
            Edge attributes for relative pose encoding. Required when a
            key or query RPE encoder is set, ignored otherwise
        :param num_pool: int
            Number of parent nodes Np. If not provided, will be inferred
            from the shape of x_parent
        :return: Tensor of shape (Np, dim), or (Np, out_dim) when an
            output projection was configured
        """
        Nc = x_child.shape[0]
        Np = x_parent.shape[0] if num_pool is None else num_pool
        H = self.num_heads
        D = self.qk_dim
        DH = D * H

        # Optional linear projection of features
        if self.in_proj is not None:
            x_child = self.in_proj(x_child)

        # Compute queries from parent features
        q = self._get_query(x_parent)  # [Np, DH]

        # Compute keys and values from child features
        kv = self.kv(x_child)  # [Nc, DH + C]

        # Expand queries and separate keys and values
        q = q[index].view(Nc, H, D)  # [Nc, H, D]
        k = kv[:, :DH].view(Nc, H, D)  # [Nc, H, D]
        v = kv[:, DH:].view(Nc, H, -1)  # [Nc, H, C // H]

        # Apply scaling on the queries
        q = q * self.qk_scale(index)

        # TODO: add the relative positional encodings to the
        #  compatibilities here
        #  - k_rpe, q_rpe, v_rpe
        #  - pos difference, absolute distance, squared distance, centroid distance, edge distance, ...
        #  - with/out edge attributes
        #  - mlp (L-LN-A-L), learnable lookup table (see Stratified Transformer)
        #  - scalar rpe, vector rpe (see Stratified Transformer)
        if self.k_rpe is not None:
            rpe = self.k_rpe(edge_attr)

            # Expand RPE to all heads if heads share the RPE encoder
            if self.heads_share_rpe:
                rpe = rpe.repeat(1, H)

            k = k + rpe.view(Nc, H, -1)

        # NOTE(review): the query RPE is added AFTER the qk_scale
        # multiplication above, so it is not scaled — presumably
        # intentional, confirm against the reference implementation
        if self.q_rpe is not None:
            rpe = self.q_rpe(edge_attr)

            # Expand RPE to all heads if heads share the RPE encoder
            if self.heads_share_rpe:
                rpe = rpe.repeat(1, H)

            q = q + rpe.view(Nc, H, -1)

        # Compute compatibility scores from the query-key products
        compat = torch.einsum('nhd, nhd -> nh', q, k)  # [Nc, H]

        # Compute the attention scores with scaled softmax, normalized
        # within each group of children sharing the same parent
        attn = softmax(compat, index=index, dim=0, num_nodes=Np)  # [Nc, H]

        # Optional attention dropout
        if self.attn_drop is not None:
            attn = self.attn_drop(attn)

        # Apply the attention on the values and sum-reduce per parent
        x = (v * attn.unsqueeze(-1)).view(Nc, self.dim)  # [Nc, C]
        x = scatter_sum(x, index, dim=0, dim_size=Np)  # [Np, C]

        # Optional linear projection of features
        if self.out_proj is not None:
            x = self.out_proj(x)  # [Np, out_dim]

        # Optional dropout on projection of features
        if self.out_drop is not None:
            x = self.out_drop(x)  # [Np, C] or [Np, out_dim]

        return x

    def _get_query(self, x_parent):
        """Overwrite this method to implement the attentive pooling.

        :param x_parent: Tensor of shape (Np, Cp)
            Node features for the parent nodes
        :return: Tensor of shape (Np, D * H)
        """
        raise NotImplementedError

    def extra_repr(self) -> str:
        return f'dim={self.dim}, num_heads={self.num_heads}'
class AttentivePool(BaseAttentivePool):
    """Attentive pooling whose queries are computed from the parent node
    features with a learned linear projection of width `q_in_dim`.
    """

    def __init__(
            self,
            dim=None,
            q_in_dim=None,
            num_heads=1,
            in_dim=None,
            out_dim=None,
            qkv_bias=True,
            qk_dim=8,
            qk_scale=None,
            attn_drop=None,
            drop=None,
            in_rpe_dim=9,
            k_rpe=False,
            q_rpe=False,
            v_rpe=False,
            heads_share_rpe=False):
        super().__init__(
            dim=dim, num_heads=num_heads, in_dim=in_dim, out_dim=out_dim,
            qkv_bias=qkv_bias, qk_dim=qk_dim, qk_scale=qk_scale,
            attn_drop=attn_drop, drop=drop, in_rpe_dim=in_rpe_dim,
            k_rpe=k_rpe, q_rpe=q_rpe, v_rpe=v_rpe,
            heads_share_rpe=heads_share_rpe)

        # Queries will be built from input parent features, one
        # qk_dim-sized query per head
        # TODO: use an FFN here to deal with handcrafted features
        query_width = qk_dim * num_heads
        self.q = nn.Linear(q_in_dim, query_width, bias=qkv_bias)

    def _get_query(self, x_parent):
        """Build queries from input parent features.

        :param x_parent: Tensor of shape (Np, Cp)
            Node features for the parent nodes
        :return: Tensor of shape (Np, D * H)
        """
        return self.q(x_parent)  # [Np, DH]
class AttentivePoolWithLearntQueries(BaseAttentivePool):
    """Attentive pooling with learnable queries: each head learns its
    own query vector, shared by all parent nodes.
    """

    def __init__(
            self,
            dim=None,
            num_heads=1,
            in_dim=None,
            out_dim=None,
            qkv_bias=True,
            qk_dim=8,
            qk_scale=None,
            attn_drop=None,
            drop=None,
            # NOTE(review): default differs from AttentivePool's
            # in_rpe_dim=9 — presumably intentional, confirm
            in_rpe_dim=18,
            k_rpe=False,
            q_rpe=False,
            v_rpe=False,
            heads_share_rpe=False):
        super().__init__(
            dim=dim, num_heads=num_heads, in_dim=in_dim, out_dim=out_dim,
            qkv_bias=qkv_bias, qk_dim=qk_dim, qk_scale=qk_scale,
            attn_drop=attn_drop, drop=drop, in_rpe_dim=in_rpe_dim,
            k_rpe=k_rpe, q_rpe=q_rpe, v_rpe=v_rpe,
            heads_share_rpe=heads_share_rpe)

        # Each head will learn its own query and all parent nodes will
        # use these same queries. `init_weights` initializes them with a
        # truncated normal distribution
        self.q = LearnableParameter(torch.zeros(qk_dim * num_heads))
        init_weights(self.q)

    def _get_query(self, x_parent):
        """Build queries from learnable queries. The parent features are
        simply used to get the number of parent nodes and expand the
        learnt queries accordingly.

        :param x_parent: Tensor of shape (Np, Cp)
            Node features for the parent nodes
        :return: Tensor of shape (Np, D * H)
        """
        num_parents = x_parent.shape[0]
        return self.q.repeat(num_parents, 1)  # [Np, DH]