import torch
from torch import nn
from torch_geometric.nn.aggr import SumAggregation
from torch_geometric.nn.aggr import MeanAggregation
from torch_geometric.nn.aggr import MaxAggregation
from torch_geometric.nn.aggr import MinAggregation
from torch_scatter import scatter_sum
from torch_geometric.utils import softmax
from src.utils.nn import init_weights, LearnableParameter, build_qk_scale_func
__all__ = [
'pool_factory', 'SumPool', 'MeanPool', 'MaxPool', 'MinPool',
'AttentivePool', 'AttentivePoolWithLearntQueries']
def pool_factory(pool, *args, **kwargs):
"""Build a Pool module from string or from an existing module. This
helper is intended to be used as a helper in spt and Stage
constructors.
"""
if isinstance(pool, (AggregationPoolMixIn, BaseAttentivePool)):
return pool
if pool == 'max':
return MaxPool()
if pool == 'min':
return MinPool()
if pool == 'mean':
return MeanPool()
if pool == 'sum':
return SumPool()
return pool(*args, **kwargs)
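
# A minimal usage sketch for `pool_factory` (the argument values below are
# illustrative assumptions, not taken from this file):
#
#   pool = pool_factory('mean')                       # -> MeanPool()
#   pool = pool_factory(MaxPool())                    # returned unchanged
#   pool = pool_factory(AttentivePool, dim=32, q_in_dim=8)  # class + kwargs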
class AggregationPoolMixIn:
"""MixIn class to convert torch-geometric Aggregation modules into
Pool modules with our desired forward signature.
:param x_child: Tensor of shape (Nc, Cc)
Node features for the children nodes
:param x_parent: Any
Not used for Aggregation
    :param index: LongTensor of shape (Nc)
        Indices indicating the parent of each child node
    :param edge_attr: Any
        Not used for Aggregation
    :param num_pool: int
        Number of parent nodes Np. If not provided, it will be
        inferred as `index.max() + 1`
:return:
"""
def __call__(self, x_child, x_parent, index, edge_attr=None, num_pool=None):
return super().__call__(x_child, index=index, dim_size=num_pool)
class SumPool(AggregationPoolMixIn, SumAggregation):
pass
class MeanPool(AggregationPoolMixIn, MeanAggregation):
pass
class MaxPool(AggregationPoolMixIn, MaxAggregation):
pass
class MinPool(AggregationPoolMixIn, MinAggregation):
pass
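
# A minimal usage sketch for the aggregation-based pools, assuming
# illustrative sizes: 5 child nodes pooled into 2 parent nodes.
#
#   x_child = torch.randn(5, 16)                # (Nc, Cc) child features
#   index = torch.tensor([0, 0, 1, 1, 1])       # parent of each child
#   pool = MeanPool()
#   x_parent = pool(x_child, None, index, num_pool=2)   # -> (2, 16)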
class BaseAttentivePool(nn.Module):
"""Base class for attentive pooling classes. This class is not
intended to be instantiated, but avoids duplicating code between
similar child classes, which are expected to implement:
- `_get_query()`
"""
    # TODO: this module could be used for pooling from one segment level
    #  to the next, but this requires defining how. With a QKV paradigm?
    #  Then how should Q be defined for superpoints: from max-pooled or
    #  mean-pooled features, or from handcrafted features? If not QKV,
    #  simply have an FFN predict (multi-headed) attention scores to be
    #  softmaxed? How to guide pooling from the above level (same problem
    #  as for QKV)?
# TODO: see torch_geometric SoftmaxAggregation and
# AttentionalAggregation for possibilities. Among which, a
# learnable softmax temperature
def __init__(
self,
dim=None,
num_heads=1,
in_dim=None,
out_dim=None,
qkv_bias=True,
qk_dim=8,
qk_scale=None,
attn_drop=None,
drop=None,
in_rpe_dim=9,
k_rpe=False,
q_rpe=False,
v_rpe=False,
heads_share_rpe=False):
super().__init__()
        assert dim % num_heads == 0, \
            f"dim ({dim}) must be a multiple of num_heads ({num_heads})"
self.dim = dim
self.num_heads = num_heads
self.qk_dim = qk_dim
self.qk_scale = build_qk_scale_func(dim, num_heads, qk_scale)
self.heads_share_rpe = heads_share_rpe
self.kv = nn.Linear(dim, qk_dim * num_heads + dim, bias=qkv_bias)
# Build the RPE encoders, with the option of sharing weights
# across all heads
rpe_dim = qk_dim if heads_share_rpe else qk_dim * num_heads
if not isinstance(k_rpe, bool):
self.k_rpe = k_rpe
else:
self.k_rpe = nn.Linear(in_rpe_dim, rpe_dim) if k_rpe else None
if not isinstance(q_rpe, bool):
self.q_rpe = q_rpe
else:
self.q_rpe = nn.Linear(in_rpe_dim, rpe_dim) if q_rpe else None
if v_rpe:
raise NotImplementedError
self.in_proj = nn.Linear(in_dim, dim) if in_dim is not None else None
self.out_proj = nn.Linear(dim, out_dim) if out_dim is not None else None
self.attn_drop = nn.Dropout(attn_drop) \
if attn_drop is not None and attn_drop > 0 else None
self.out_drop = nn.Dropout(drop) \
if drop is not None and drop > 0 else None
def forward(
self, x_child, x_parent, index, edge_attr=None, num_pool=None):
"""
:param x_child: Tensor of shape (Nc, Cc)
Node features for the children nodes
:param x_parent: Tensor of shape (Np, Cp)
Node features for the parent nodes
        :param index: LongTensor of shape (Nc)
            Indices indicating the parent of each child node
        :param edge_attr: FloatTensor of shape (Nc, F)
            Edge attributes for relative pose encoding
        :param num_pool: int
            Number of parent nodes Np. If not provided, it will be
            inferred from the shape of x_parent
:return:
"""
Nc = x_child.shape[0]
Np = x_parent.shape[0] if num_pool is None else num_pool
H = self.num_heads
D = self.qk_dim
DH = D * H
# Optional linear projection of features
if self.in_proj is not None:
x_child = self.in_proj(x_child)
# Compute queries from parent features
q = self._get_query(x_parent) # [Np, DH]
# Compute keys and values from child features
kv = self.kv(x_child) # [Nc, DH + C]
# Expand queries and separate keys and values
q = q[index].view(Nc, H, D) # [Nc, H, D]
k = kv[:, :DH].view(Nc, H, D) # [Nc, H, D]
v = kv[:, DH:].view(Nc, H, -1) # [Nc, H, C // H]
# Apply scaling on the queries
q = q * self.qk_scale(index)
# TODO: add the relative positional encodings to the
# compatibilities here
# - k_rpe, q_rpe, v_rpe
# - pos difference, absolute distance, squared distance, centroid distance, edge distance, ...
# - with/out edge attributes
# - mlp (L-LN-A-L), learnable lookup table (see Stratified Transformer)
# - scalar rpe, vector rpe (see Stratified Transformer)
if self.k_rpe is not None:
rpe = self.k_rpe(edge_attr)
# Expand RPE to all heads if heads share the RPE encoder
if self.heads_share_rpe:
rpe = rpe.repeat(1, H)
k = k + rpe.view(Nc, H, -1)
if self.q_rpe is not None:
rpe = self.q_rpe(edge_attr)
# Expand RPE to all heads if heads share the RPE encoder
if self.heads_share_rpe:
rpe = rpe.repeat(1, H)
q = q + rpe.view(Nc, H, -1)
# Compute compatibility scores from the query-key products
compat = torch.einsum('nhd, nhd -> nh', q, k) # [Nc, H]
# Compute the attention scores with scaled softmax
attn = softmax(compat, index=index, dim=0, num_nodes=Np) # [Nc, H]
# Optional attention dropout
if self.attn_drop is not None:
attn = self.attn_drop(attn)
# Apply the attention on the values
x = (v * attn.unsqueeze(-1)).view(Nc, self.dim) # [Nc, C]
x = scatter_sum(x, index, dim=0, dim_size=Np) # [Np, C]
# Optional linear projection of features
if self.out_proj is not None:
x = self.out_proj(x) # [Np, out_dim]
# Optional dropout on projection of features
if self.out_drop is not None:
x = self.out_drop(x) # [Np, C] or [Np, out_dim]
return x
def _get_query(self, x_parent):
"""Overwrite this method to implement the attentive pooling.
:param x_parent: Tensor of shape (Np, Cp)
Node features for the parent nodes
:return: Tensor of shape (Np, D * H)
"""
raise NotImplementedError
def extra_repr(self) -> str:
return f'dim={self.dim}, num_heads={self.num_heads}'
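
# A minimal sketch of the scatter-softmax semantics used in the forward
# pass above (illustrative sizes): `softmax(..., index=index)` normalizes
# the per-head compatibilities so that attention weights sum to 1 over the
# children of each parent, and `scatter_sum` then aggregates the weighted
# values parent-wise.
#
#   compat = torch.randn(5, 2)                  # (Nc, H) raw scores
#   index = torch.tensor([0, 0, 1, 1, 1])       # parent of each child
#   attn = softmax(compat, index=index, dim=0, num_nodes=2)
#   attn[:2].sum(dim=0)                         # -> tensor([1., 1.])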
class AttentivePool(BaseAttentivePool):
def __init__(
self,
dim=None,
q_in_dim=None,
num_heads=1,
in_dim=None,
out_dim=None,
qkv_bias=True,
qk_dim=8,
qk_scale=None,
attn_drop=None,
drop=None,
in_rpe_dim=9,
k_rpe=False,
q_rpe=False,
v_rpe=False,
heads_share_rpe=False):
super().__init__(
dim=dim,
num_heads=num_heads,
in_dim=in_dim,
out_dim=out_dim,
qkv_bias=qkv_bias,
qk_dim=qk_dim,
qk_scale=qk_scale,
attn_drop=attn_drop,
drop=drop,
in_rpe_dim=in_rpe_dim,
k_rpe=k_rpe,
q_rpe=q_rpe,
v_rpe=v_rpe,
heads_share_rpe=heads_share_rpe)
        # Queries will be built from the input parent features
        # TODO: use an FFN here to deal with handcrafted features
        self.q = nn.Linear(q_in_dim, qk_dim * num_heads, bias=qkv_bias)
def _get_query(self, x_parent):
"""Build queries from input parent features
:param x_parent: Tensor of shape (Np, Cp)
Node features for the parent nodes
:return: Tensor of shape (Np, D * H)
"""
return self.q(x_parent) # [Np, DH]
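
# A minimal usage sketch for AttentivePool (all sizes are illustrative
# assumptions): queries are computed from the parent features.
#
#   pool = AttentivePool(dim=32, q_in_dim=8, num_heads=4, qk_dim=8)
#   x_child = torch.randn(5, 32)                # (Nc, dim)
#   x_parent = torch.randn(2, 8)                # (Np, q_in_dim)
#   index = torch.tensor([0, 0, 1, 1, 1])
#   out = pool(x_child, x_parent, index)        # -> (2, 32)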
class AttentivePoolWithLearntQueries(BaseAttentivePool):
def __init__(
self,
dim=None,
num_heads=1,
in_dim=None,
out_dim=None,
qkv_bias=True,
qk_dim=8,
qk_scale=None,
attn_drop=None,
drop=None,
in_rpe_dim=18,
k_rpe=False,
q_rpe=False,
v_rpe=False,
heads_share_rpe=False):
super().__init__(
dim=dim,
num_heads=num_heads,
in_dim=in_dim,
out_dim=out_dim,
qkv_bias=qkv_bias,
qk_dim=qk_dim,
qk_scale=qk_scale,
attn_drop=attn_drop,
drop=drop,
in_rpe_dim=in_rpe_dim,
k_rpe=k_rpe,
q_rpe=q_rpe,
v_rpe=v_rpe,
heads_share_rpe=heads_share_rpe)
# Each head will learn its own query and all parent nodes will
# use these same queries.
self.q = LearnableParameter(torch.zeros(qk_dim * num_heads))
# `init_weights` initializes the weights with a truncated normal
# distribution
init_weights(self.q)
def _get_query(self, x_parent):
"""Build queries from learnable queries. The parent features are
simply used to get the number of parent nodes and expand the
learnt queries accordingly.
:param x_parent: Tensor of shape (Np, Cp)
Node features for the parent nodes
:return: Tensor of shape (Np, D * H)
"""
Np = x_parent.shape[0]
return self.q.repeat(Np, 1)
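
# A minimal usage sketch for AttentivePoolWithLearntQueries (illustrative
# sizes): x_parent is only used to infer the number of parents, so
# placeholder parent features of any width suffice here.
#
#   pool = AttentivePoolWithLearntQueries(dim=32, num_heads=4, qk_dim=8)
#   x_child = torch.randn(5, 32)                # (Nc, dim)
#   x_parent = torch.empty(2, 0)                # only Np=2 is used
#   index = torch.tensor([0, 0, 1, 1, 1])
#   out = pool(x_child, x_parent, index)        # -> (2, 32)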