import torch
from torch import nn
from torch_geometric.nn.aggr import SumAggregation
from torch_geometric.nn.aggr import MeanAggregation
from torch_geometric.nn.aggr import MaxAggregation
from torch_geometric.nn.aggr import MinAggregation
from torch_scatter import scatter_sum
from torch_geometric.utils import softmax
from src.utils.nn import init_weights, LearnableParameter, build_qk_scale_func


__all__ = [
    'pool_factory', 'SumPool', 'MeanPool', 'MaxPool', 'MinPool',
    'AttentivePool', 'AttentivePoolWithLearntQueries']


def pool_factory(pool, *args, **kwargs):
    """Build a Pool module from string or from an existing module. This
    helper is intended to be used as a helper in spt and Stage
    constructors.
    """
    if isinstance(pool, (AggregationPoolMixIn, BaseAttentivePool)):
        return pool
    if pool == 'max':
        return MaxPool()
    if pool == 'min':
        return MinPool()
    if pool == 'mean':
        return MeanPool()
    if pool == 'sum':
        return SumPool()
    return pool(*args, **kwargs)
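
# Usage sketch (hypothetical tensors, not defined in this module):
#   pool = pool_factory('mean')  # or 'sum', 'max', 'min'
#   x_parent = pool(x_child, None, index, num_pool=num_parents)
#   # Any other class is instantiated with the remaining arguments:
#   pool = pool_factory(AttentivePool, dim=64, q_in_dim=32)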


class AggregationPoolMixIn:
    """MixIn class to convert torch-geometric Aggregation modules into
    Pool modules with our desired forward signature.

    :param x_child: Tensor of shape (Nc, Cc)
        Node features for the child nodes
    :param x_parent: Any
        Not used for Aggregation
    :param index: LongTensor of shape (Nc)
        Index indicating the parent of each child node
    :param edge_attr: Any
        Not used for Aggregation
    :param num_pool: int
        Number of parent nodes Np. If not provided, will be inferred
        from `index.max() + 1`
    :return: Tensor of shape (Np, Cc)
        Pooled node features for the parent nodes
    """
    def __call__(self, x_child, x_parent, index, edge_attr=None, num_pool=None):
        return super().__call__(x_child, index=index, dim_size=num_pool)


class SumPool(AggregationPoolMixIn, SumAggregation):
    pass


class MeanPool(AggregationPoolMixIn, MeanAggregation):
    pass


class MaxPool(AggregationPoolMixIn, MaxAggregation):
    pass


class MinPool(AggregationPoolMixIn, MinAggregation):
    pass
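
# A minimal example of the shared Pool signature (illustrative values):
#   x_child = torch.randn(5, 16)           # 5 child nodes with 16 features
#   index = torch.tensor([0, 0, 1, 1, 1])  # parent of each child node
#   x_parent = MaxPool()(x_child, None, index, num_pool=2)  # [2, 16]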


class BaseAttentivePool(nn.Module):
    """Base class for attentive pooling classes. This class is not
    intended to be instantiated, but avoids duplicating code between
    similar child classes, which are expected to implement:
      - `_get_query()`
    """

    # TODO: this module could be used for pooling from one segment level
    #  to the next. But this requires defining how. With a QKV paradigm?
    #  Then how to define Q for superpoints? From max-pooled/mean-pooled
    #  features? From handcrafted features? If not QKV, simply have an
    #  FFN predict (multi-headed) attention scores to be softmaxed? How
    #  to guide pooling from the above level (same problem as for QKV)?

    # TODO: see torch_geometric SoftmaxAggregation and
    #  AttentionalAggregation for possibilities. Among which, a
    #  learnable softmax temperature

    def __init__(
            self,
            dim=None,
            num_heads=1,
            in_dim=None,
            out_dim=None,
            qkv_bias=True,
            qk_dim=8,
            qk_scale=None,
            attn_drop=None,
            drop=None,
            in_rpe_dim=9,
            k_rpe=False,
            q_rpe=False,
            v_rpe=False,
            heads_share_rpe=False):
        super().__init__()

        assert dim % num_heads == 0, \
            f"dim ({dim}) must be a multiple of num_heads ({num_heads})"

        self.dim = dim
        self.num_heads = num_heads
        self.qk_dim = qk_dim
        self.qk_scale = build_qk_scale_func(dim, num_heads, qk_scale)
        self.heads_share_rpe = heads_share_rpe

        self.kv = nn.Linear(dim, qk_dim * num_heads + dim, bias=qkv_bias)

        # Build the RPE encoders, with the option of sharing weights
        # across all heads
        rpe_dim = qk_dim if heads_share_rpe else qk_dim * num_heads

        if not isinstance(k_rpe, bool):
            self.k_rpe = k_rpe
        else:
            self.k_rpe = nn.Linear(in_rpe_dim, rpe_dim) if k_rpe else None

        if not isinstance(q_rpe, bool):
            self.q_rpe = q_rpe
        else:
            self.q_rpe = nn.Linear(in_rpe_dim, rpe_dim) if q_rpe else None

        if v_rpe:
            raise NotImplementedError

        self.in_proj = nn.Linear(in_dim, dim) if in_dim is not None else None
        self.out_proj = nn.Linear(dim, out_dim) if out_dim is not None else None

        self.attn_drop = nn.Dropout(attn_drop) \
            if attn_drop is not None and attn_drop > 0 else None
        self.out_drop = nn.Dropout(drop) \
            if drop is not None and drop > 0 else None

    def forward(
            self, x_child, x_parent, index, edge_attr=None, num_pool=None):
        """
        :param x_child: Tensor of shape (Nc, Cc)
            Node features for the child nodes
        :param x_parent: Tensor of shape (Np, Cp)
            Node features for the parent nodes
        :param index: LongTensor of shape (Nc)
            Index indicating the parent of each child node
        :param edge_attr: FloatTensor of shape (Nc, F)
            Edge attributes for relative position encoding
        :param num_pool: int
            Number of parent nodes Np. If not provided, will be inferred
            from the shape of x_parent
        :return: Tensor of shape (Np, C), or (Np, out_dim) if an output
            projection is used
        """
        Nc = x_child.shape[0]
        Np = x_parent.shape[0] if num_pool is None else num_pool
        H = self.num_heads
        D = self.qk_dim
        DH = D * H

        # Optional linear projection of features
        if self.in_proj is not None:
            x_child = self.in_proj(x_child)

        # Compute queries from parent features
        q = self._get_query(x_parent)  # [Np, DH]

        # Compute keys and values from child features
        kv = self.kv(x_child)  # [Nc, DH + C]

        # Expand queries and separate keys and values
        q = q[index].view(Nc, H, D)     # [Nc, H, D]
        k = kv[:, :DH].view(Nc, H, D)   # [Nc, H, D]
        v = kv[:, DH:].view(Nc, H, -1)  # [Nc, H, C // H]

        # Apply scaling on the queries
        q = q * self.qk_scale(index)

        # TODO: add the relative positional encodings to the
        #  compatibilities here
        #  - k_rpe, q_rpe, v_rpe
        #  - pos difference, absolute distance, squared distance, centroid distance, edge distance, ...
        #  - with/out edge attributes
        #  - mlp (L-LN-A-L), learnable lookup table (see Stratified Transformer)
        #  - scalar rpe, vector rpe (see Stratified Transformer)
        if self.k_rpe is not None:
            rpe = self.k_rpe(edge_attr)

            # Expand RPE to all heads if heads share the RPE encoder
            if self.heads_share_rpe:
                rpe = rpe.repeat(1, H)

            k = k + rpe.view(Nc, H, -1)

        if self.q_rpe is not None:
            rpe = self.q_rpe(edge_attr)

            # Expand RPE to all heads if heads share the RPE encoder
            if self.heads_share_rpe:
                rpe = rpe.repeat(1, H)

            q = q + rpe.view(Nc, H, -1)

        # Compute compatibility scores from the query-key products
        compat = torch.einsum('nhd, nhd -> nh', q, k)  # [Nc, H]

        # Compute the attention scores with a softmax over the children
        # of each parent node
        attn = softmax(compat, index=index, dim=0, num_nodes=Np)  # [Nc, H]

        # Optional attention dropout
        if self.attn_drop is not None:
            attn = self.attn_drop(attn)

        # Apply the attention on the values
        x = (v * attn.unsqueeze(-1)).view(Nc, self.dim)  # [Nc, C]
        x = scatter_sum(x, index, dim=0, dim_size=Np)  # [Np, C]

        # Optional linear projection of features
        if self.out_proj is not None:
            x = self.out_proj(x)  # [Np, out_dim]

        # Optional dropout on projection of features
        if self.out_drop is not None:
            x = self.out_drop(x)  # [Np, C] or [Np, out_dim]

        return x

    def _get_query(self, x_parent):
        """Overwrite this method to implement the attentive pooling.

        :param x_parent: Tensor of shape (Np, Cp)
            Node features for the parent nodes

        :return: Tensor of shape (Np, D * H)
        """
        raise NotImplementedError

    def extra_repr(self) -> str:
        return f'dim={self.dim}, num_heads={self.num_heads}'
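
# Subclasses only need to implement `_get_query()`. A minimal sketch
# (hypothetical; see AttentivePool below for an actual implementation):
#   class MyPool(BaseAttentivePool):
#       def _get_query(self, x_parent):
#           return self.q(x_parent)  # [Np, qk_dim * num_heads]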


class AttentivePool(BaseAttentivePool):
    def __init__(
            self,
            dim=None,
            q_in_dim=None,
            num_heads=1,
            in_dim=None,
            out_dim=None,
            qkv_bias=True,
            qk_dim=8,
            qk_scale=None,
            attn_drop=None,
            drop=None,
            in_rpe_dim=9,
            k_rpe=False,
            q_rpe=False,
            v_rpe=False,
            heads_share_rpe=False):
        super().__init__(
            dim=dim,
            num_heads=num_heads,
            in_dim=in_dim,
            out_dim=out_dim,
            qkv_bias=qkv_bias,
            qk_dim=qk_dim,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            drop=drop,
            in_rpe_dim=in_rpe_dim,
            k_rpe=k_rpe,
            q_rpe=q_rpe,
            v_rpe=v_rpe,
            heads_share_rpe=heads_share_rpe)

        # Queries will be built from the input parent features
        self.q = nn.Linear(q_in_dim, qk_dim * num_heads, bias=qkv_bias)  # TODO: use an FFN here to deal with handcrafted features

    def _get_query(self, x_parent):
        """Build queries from input parent features

        :param x_parent: Tensor of shape (Np, Cp)
            Node features for the parent nodes

        :return: Tensor of shape (Np, D * H)
        """
        return self.q(x_parent)  # [Np, DH]
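
# Usage sketch (hypothetical shapes; parent features drive the queries):
#   pool = AttentivePool(dim=32, q_in_dim=8, num_heads=4)
#   x = pool(x_child, x_parent, index)  # x_child: [Nc, 32], x_parent: [Np, 8]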


class AttentivePoolWithLearntQueries(BaseAttentivePool):
    def __init__(
            self,
            dim=None,
            num_heads=1,
            in_dim=None,
            out_dim=None,
            qkv_bias=True,
            qk_dim=8,
            qk_scale=None,
            attn_drop=None,
            drop=None,
            in_rpe_dim=18,
            k_rpe=False,
            q_rpe=False,
            v_rpe=False,
            heads_share_rpe=False):
        super().__init__(
            dim=dim,
            num_heads=num_heads,
            in_dim=in_dim,
            out_dim=out_dim,
            qkv_bias=qkv_bias,
            qk_dim=qk_dim,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            drop=drop,
            in_rpe_dim=in_rpe_dim,
            k_rpe=k_rpe,
            q_rpe=q_rpe,
            v_rpe=v_rpe,
            heads_share_rpe=heads_share_rpe)

        # Each head will learn its own query and all parent nodes will
        # use these same queries.
        self.q = LearnableParameter(torch.zeros(qk_dim * num_heads))

        # `init_weights` initializes the weights with a truncated normal
        # distribution
        init_weights(self.q)

    def _get_query(self, x_parent):
        """Build queries from learnable queries. The parent features are
        simply used to get the number of parent nodes and expand the
        learnt queries accordingly.

        :param x_parent: Tensor of shape (Np, Cp)
            Node features for the parent nodes

        :return: Tensor of shape (Np, D * H)
        """
        Np = x_parent.shape[0]
        return self.q.repeat(Np, 1)
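
# Usage sketch (hypothetical shapes; queries are learnt and shared by all
# parent nodes, x_parent only provides the number of parents):
#   pool = AttentivePoolWithLearntQueries(dim=32, num_heads=4)
#   x = pool(x_child, x_parent, index)  # -> [Np, 32]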