English
File size: 23,467 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
import torch
from src.data import Data, NAG, CSRData
from src.transforms import Transform
from src.utils import tensor_idx, to_float_rgb, to_byte_rgb, dropout, \
    sanitize_keys


# Names exported by `from <module> import *`: every transform class
# defined in this module
__all__ = [
    'DataToNAG', 'NAGToData', 'Cast', 'NAGCast', 'RemoveKeys', 'NAGRemoveKeys',
    'AddKeysTo', 'NAGAddKeysTo', 'NAGSelectByKey', 'SelectColumns',
    'NAGSelectColumns', 'DropoutColumns', 'NAGDropoutColumns', 'DropoutRows',
    'NAGDropoutRows', 'NAGJitterKey']


class DataToNAG(Transform):
    """Wrap a Data object into a NAG whose only level is that Data."""

    _IN_TYPE = Data
    _OUT_TYPE = NAG

    def _process(self, data):
        # The input Data becomes the single level of the output NAG
        levels = [data]
        return NAG(levels)


class NAGToData(Transform):
    """Extract the only level of a single-level NAG as a Data object.

    The input NAG is asserted to hold exactly one level.
    """

    _IN_TYPE = NAG
    _OUT_TYPE = Data

    def _process(self, nag):
        # Only NAGs made of exactly one level can be converted
        assert nag.num_levels == 1
        only_level = nag[0]
        return only_level


class Cast(Transform):
    """Cast the attributes of a Data object to the requested dtypes.

    Floating point tensors are cast to `fp_dtype` and all other tensors
    to `int_dtype`. The 'rgb' and 'mean_rgb' attributes are handled
    separately, based on `rgb_to_float`. CSRData attributes have their
    values cast with the same rules and their pointers cast to long.
    Any other attribute is left untouched.

    :param fp_dtype: torch.dtype
        Target dtype for floating point tensors
    :param int_dtype: torch.dtype
        Target dtype for non-floating point tensors
    :param rgb_to_float: bool
        If True, 'rgb' and 'mean_rgb' are converted with `to_float_rgb`
        and cast to `fp_dtype`. Otherwise, they are converted with
        `to_byte_rgb`
    """

    def __init__(
            self, fp_dtype=torch.float, int_dtype=torch.long,
            rgb_to_float=True):
        self.rgb_to_float = rgb_to_float
        self.int_dtype = int_dtype
        self.fp_dtype = fp_dtype

    def _process(self, data):
        for key in data.keys:
            value = data[key]

            if isinstance(value, CSRData):
                # CSRData attributes (e.g. Cluster for 'sub' key) are
                # handled recursively: each of their values is wrapped
                # in a temporary Data so as to reuse the present rules
                value.values = [
                    self._process(Data(foo=v)).foo for v in value.values]
                value.pointers = value.pointers.long()

            elif key in ('rgb', 'mean_rgb'):
                # Colors have their own float/byte conversion helpers
                if self.rgb_to_float:
                    data[key] = to_float_rgb(value).to(self.fp_dtype)
                else:
                    data[key] = to_byte_rgb(value)

            elif isinstance(value, torch.Tensor):
                # Regular tensors: floats go to fp_dtype, anything else
                # goes to int_dtype
                if value.is_floating_point():
                    data[key] = value.to(self.fp_dtype)
                else:
                    data[key] = value.to(self.int_dtype)

            # Anything else is left as is

        return data


class NAGCast(Cast):
    """Apply a `Cast` transform, built with this transform's own
    `fp_dtype`, `int_dtype` and `rgb_to_float`, to every level of a NAG.
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def _process(self, nag):
        # Data-level transform sharing the parameters of this transform
        cast_kwargs = dict(
            fp_dtype=self.fp_dtype,
            int_dtype=self.int_dtype,
            rgb_to_float=self.rgb_to_float)
        caster = Cast(**cast_kwargs)

        # Replace each level by its cast counterpart
        num_levels = nag.num_levels
        for i_level in range(num_levels):
            nag._list[i_level] = caster(nag[i_level])

        return nag


class RemoveKeys(Transform):
    """Delete attributes from a Data object, based on their name.

    :param keys: str or list(str)
        Name(s) of the attributes to remove
    :param strict: bool
        If True, an exception is raised when one of `keys` is not found
        among the input Data keys. The check runs on all keys before any
        attribute is removed
    """

    _NO_REPR = ['strict']

    def __init__(self, keys=None, strict=False):
        self.strict = strict
        self.keys = sanitize_keys(keys, default=[])

    def _process(self, data):
        present = set(data.keys)

        # In strict mode, make sure all keys exist before removing any
        if self.strict:
            for name in self.keys:
                if name not in present:
                    raise Exception(
                        f"key: {name} is not within Data keys: {present}")

        for name in self.keys:
            delattr(data, name)

        return data


class NAGRemoveKeys(Transform):
    """Remove attributes of a NAG object based on their name.

    :param level: int or str
        Level at which to remove attributes. Can be an int or a str. If
        the latter, 'all' will apply on all levels, 'i+' will apply on
        level-i and above, 'i-' will apply on level-i and below. In
        both the 'i+' and 'i-' cases, level-i itself is included
    :param keys: str or list(str)
        List of attribute names
    :param strict: bool=False
        If True, will raise an exception if an attribute from key is
        not within the input Data keys
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG
    _NO_REPR = ['strict']

    def __init__(self, level='all', keys=None, strict=False):
        assert isinstance(level, (int, str))
        self.level = level
        self.keys = sanitize_keys(keys, default=[])
        self.strict = strict

    def _process(self, nag):
        num_levels = nag.num_levels

        # For each level, the keys to remove. Levels which are not
        # targeted keep an empty list, making RemoveKeys a no-op there
        level_keys = [[]] * num_levels
        if isinstance(self.level, int):
            level_keys[self.level] = self.keys
        elif self.level == 'all':
            level_keys = [self.keys] * num_levels
        elif self.level[-1] == '+':
            i = int(self.level[:-1])
            level_keys[i:] = [self.keys] * (num_levels - i)
        elif self.level[-1] == '-':
            # 'i-' covers levels 0 to i, with level-i included, as
            # documented. The bound is clamped to the number of levels
            i = int(self.level[:-1])
            stop = max(0, min(i + 1, num_levels))
            level_keys[:stop] = [self.keys] * stop
        else:
            raise ValueError(f'Unsupported level={self.level}')

        # One Data-level transform per NAG level
        transforms = [
            RemoveKeys(keys=k, strict=self.strict) for k in level_keys]

        for i_level in range(num_levels):
            nag._list[i_level] = transforms[i_level](nag._list[i_level])

        return nag


class AddKeysTo(Transform):
    """Concatenate some Data attributes, column-wise, to a destination
    attribute.

    :param keys: str or list(str)
        Name(s) of the attributes to concatenate to 'to'
    :param to: str
        Name of the destination attribute. If the Data does not hold it
        yet, it is created from the first attribute found
    :param strict: bool
        Whether an error should be raised if a key is not found
    :param delete_after: bool
        Whether each source attribute should be removed from the Data
        once added to 'to'
    """

    _NO_REPR = ['strict']

    def __init__(self, keys=None, to='x', strict=True, delete_after=True):
        self.keys = [keys] if isinstance(keys, str) else keys
        self.to = to
        self.strict = strict
        self.delete_after = delete_after

    def _process_single_key(self, data, key, to):
        # Fetch the source attribute and the current destination
        source = getattr(data, key, None)
        target = getattr(data, to, None)

        # Missing source: error in strict mode, no-op otherwise
        if source is None:
            if not self.strict:
                return data
            raise Exception(f"Data should contain the attribute '{key}'")

        # The source must be removed before the destination is written,
        # so that the result is preserved when `key` and `to` coincide
        if self.delete_after:
            delattr(data, key)

        # No destination yet: the source becomes the destination, with
        # a trailing dimension added if it is 1D
        if target is None:
            if self.strict and data.num_nodes != source.shape[0]:
                raise Exception(f"Data should contain the attribute '{to}'")
            data[to] = source.unsqueeze(-1) if source.dim() == 1 else source
            return data

        # Both tensors must have the same number of rows
        if target.shape[0] != source.shape[0]:
            raise Exception(
                f"The tensors '{to}' and '{key}' can't be concatenated, "
                f"'{to}': {target.shape[0]}, '{key}': {source.shape[0]}")

        # Promote 1D tensors to single-column 2D ones and concatenate
        columns = [
            t.unsqueeze(-1) if t.dim() == 1 else t for t in (target, source)]
        data[to] = torch.cat(columns, dim=-1)

        return data

    def _process(self, data):
        has_keys = self.keys is not None and len(self.keys) > 0
        if has_keys:
            for name in self.keys:
                data = self._process_single_key(data, name, self.to)
        return data


class NAGAddKeysTo(Transform):
    """Get attributes from their keys and concatenate them to a
    destination attribute, at the chosen levels of a NAG.

    :param level: int or str
        Level at which to add attributes. Can be an int or a str. If
        the latter, 'all' will apply on all levels, 'i+' will apply on
        level-i and above, 'i-' will apply on level-i and below. In
        both the 'i+' and 'i-' cases, level-i itself is included
    :param keys: str or list(str)
        The feature concatenated to 'to'
    :param to: str
        Destination attribute where the features in 'keys' will be
        concatenated
    :param strict: bool
        Whether we want to raise an error if a key is not found
    :param delete_after: bool
        Whether the Data attributes should be removed once added to 'to'
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG
    _NO_REPR = ['strict']

    def __init__(
            self, level='all', keys=None, to='x', strict=True,
            delete_after=True):
        self.level = level
        self.keys = [keys] if isinstance(keys, str) else keys
        self.to = to
        self.strict = strict
        self.delete_after = delete_after

    def _process(self, nag):
        num_levels = nag.num_levels

        # For each level, the keys to concatenate. Levels which are not
        # targeted keep an empty list, making AddKeysTo a no-op there
        level_keys = [[]] * num_levels
        if isinstance(self.level, int):
            level_keys[self.level] = self.keys
        elif self.level == 'all':
            level_keys = [self.keys] * num_levels
        elif self.level[-1] == '+':
            i = int(self.level[:-1])
            level_keys[i:] = [self.keys] * (num_levels - i)
        elif self.level[-1] == '-':
            # 'i-' covers levels 0 to i, with level-i included, as
            # documented. The bound is clamped to the number of levels
            i = int(self.level[:-1])
            stop = max(0, min(i + 1, num_levels))
            level_keys[:stop] = [self.keys] * stop
        else:
            raise ValueError(f'Unsupported level={self.level}')

        # One Data-level transform per NAG level
        transforms = [
            AddKeysTo(
                keys=k, to=self.to, strict=self.strict,
                delete_after=self.delete_after)
            for k in level_keys]

        for i_level in range(num_levels):
            nag._list[i_level] = transforms[i_level](nag._list[i_level])

        return nag


class NAGSelectByKey(Transform):
    """Select nodes of a given NAG level using a boolean mask stored in
    one of the attributes of this level.

    :param key: str
        Name of the attribute, at level `level` of the input NAG,
        holding the mask. It must be a 1D boolean tensor with one entry
        per node of that level
    :param level: int
        NAG level on which the selection is performed
    :param negation: bool
        If True, the complement of the mask is used for the selection
    :param strict: bool
        If True, an error is raised when the key is missing, or when it
        does not hold a boolean mask of the expected shape. If False,
        the input NAG is returned unchanged in those cases
    :param delete_after: bool
        Whether the `key` attribute should be removed after selection
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG
    _NO_REPR = ['strict']

    def __init__(
            self, key=None, level=0, negation=False, strict=True,
            delete_after=True):
        assert key is not None
        self.key = key
        self.level = level
        self.negation = negation
        self.strict = strict
        self.delete_after = delete_after

    def _process(self, nag):
        level_data = nag[self.level]

        # Look for the first reason why the mask cannot be used: missing
        # key, wrong dtype, or wrong shape
        mask = None
        problem = None
        if self.key not in level_data.keys:
            problem = (
                f'Input NAG does not have `{self.key}` attribute at '
                f'level `{self.level}`')
        else:
            mask = level_data[self.key]
            dtype = mask.dtype
            if dtype != torch.bool:
                problem = (
                    f'`{self.key}` attribute has dtype={dtype} but '
                    f'dtype=torch.bool was expected')
            else:
                expected_size = torch.Size((level_data.num_nodes,))
                actual_size = mask.shape
                if expected_size != actual_size:
                    problem = (
                        f'`{self.key}` attribute has shape={actual_size} but '
                        f'shape={expected_size} was expected')

        # Unusable mask: error in strict mode, no-op otherwise
        if problem is not None:
            if self.strict:
                raise ValueError(problem)
            return nag

        # Convert the (possibly negated) mask into node indices and
        # select the corresponding `level` nodes
        if self.negation:
            mask = ~mask
        selected = torch.where(mask)[0]
        nag = nag.select(self.level, selected)

        # Discard the mask attribute from the output, if required
        if self.delete_after:
            nag[self.level][self.key] = None

        return nag


class SelectColumns(Transform):
    """Keep only some columns of a Data attribute, based on their
    indices.

    :param key: str
        Name of the Data attribute whose columns should be selected
    :param idx: int, Tensor or list
        Indices of the columns to keep. If None, this transform has no
        effect and the attribute is left untouched
    """

    def __init__(self, key=None, idx=None):
        assert key is not None, "A Data key must be specified"
        self.key = key
        self.idx = None if idx is None else tensor_idx(idx)

    def _process(self, data):
        # Nothing to do without indices
        if self.idx is None:
            return data

        # Nothing to do if the attribute is absent from the Data
        if getattr(data, self.key, None) is None:
            return data

        # Bring the indices onto the Data device before indexing
        columns = tensor_idx(torch.as_tensor(self.idx, device=data.device))
        data[self.key] = data[self.key][:, columns]

        return data


class NAGSelectColumns(Transform):
    """Select columns of an attribute based on their indices, at the
    chosen levels of a NAG.

    :param level: int or str
        Level at which to select attributes. Can be an int or a str. If
        the latter, 'all' will apply on all levels, 'i+' will apply on
        level-i and above, 'i-' will apply on level-i and below. In
        both the 'i+' and 'i-' cases, level-i itself is included
    :param key: str
        The Data attribute whose columns should be selected
    :param idx: int, Tensor or list
        The indices of the columns to keep. If None, this transform
        will have no effect and the attribute will be left untouched
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, level='all', key=None, idx=None):
        self.level = level
        self.key = key
        self.idx = idx

    def _process(self, nag):
        num_levels = nag.num_levels

        # For each level, the column indices to keep. Levels which are
        # not targeted keep None, making SelectColumns a no-op there
        level_idx = [None] * num_levels
        if isinstance(self.level, int):
            level_idx[self.level] = self.idx
        elif self.level == 'all':
            level_idx = [self.idx] * num_levels
        elif self.level[-1] == '+':
            i = int(self.level[:-1])
            level_idx[i:] = [self.idx] * (num_levels - i)
        elif self.level[-1] == '-':
            # 'i-' covers levels 0 to i, with level-i included, as
            # documented. The bound is clamped to the number of levels
            i = int(self.level[:-1])
            stop = max(0, min(i + 1, num_levels))
            level_idx[:stop] = [self.idx] * stop
        else:
            raise ValueError(f'Unsupported level={self.level}')

        # One Data-level transform per NAG level
        transforms = [SelectColumns(key=self.key, idx=idx) for idx in level_idx]

        for i_level in range(num_levels):
            nag._list[i_level] = transforms[i_level](nag._list[i_level])

        return nag


class DropoutColumns(Transform):
    """Randomly drop columns of a Data attribute, by calling `dropout`
    along dim=1.

    :param p: float
        Probability of a column to be dropped. No dropout is applied if
        p <= 0
    :param key: str
        Name of the Data attribute whose columns may be dropped
    :param inplace: bool
        Whether the dropout should be performed directly on the input
        or on a copy of it
    :param to_mean: bool
        Whether the dropped values should be set to a mean value
        instead of zero (default). Passed as is to `dropout`
    """

    def __init__(self, p=0.5, key=None, inplace=False, to_mean=False):
        assert key is not None, "A Data key must be specified"
        self.to_mean = to_mean
        self.inplace = inplace
        self.key = key
        self.p = p

    def _process(self, data):
        # Nothing to do if dropout is disabled or if the attribute is
        # absent from the input Data
        if self.p <= 0 or getattr(data, self.key, None) is None:
            return data

        # Column-wise dropout
        options = dict(
            p=self.p, dim=1, inplace=self.inplace, to_mean=self.to_mean)
        data[self.key] = dropout(data[self.key], **options)

        return data


class NAGDropoutColumns(Transform):
    """Randomly drop columns of a Data attribute, at the chosen levels
    of a NAG, by calling `dropout` along dim=1.

    :param level: int or str
        Level at which to drop columns. Can be an int or a str. If
        the latter, 'all' will apply on all levels, 'i+' will apply on
        level-i and above, 'i-' will apply on level-i and below
    :param p: float
        Probability of a column to be dropped. No dropout is applied if
        p <= 0
    :param key: str
        Name of the Data attribute whose columns may be dropped
    :param inplace: bool
        Whether the dropout should be performed directly on the input
        or on a copy of it
    :param to_mean: bool
        Whether the dropped values should be set to a mean value
        instead of zero (default). Passed as is to `dropout`
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(
            self, level='all', p=0.5, key=None, inplace=False, to_mean=False):
        assert isinstance(level, int) or level == 'all' or level.endswith('-') \
               or level.endswith('+')
        self.level = level
        self.p = p
        self.key = key
        self.inplace = inplace
        self.to_mean = to_mean

    def _process(self, nag):
        # Nothing to do if dropout is disabled
        if self.p <= 0:
            return nag

        # Convert `level` into the levels to process
        level = self.level
        if isinstance(level, int):
            levels = [level]
        elif level == 'all':
            levels = range(0, nag.num_levels)
        else:
            suffix = level[-1]
            if suffix == '+':
                levels = range(int(level[:-1]), nag.num_levels)
            elif suffix == '-':
                levels = range(0, int(level[:-1]) + 1)
            else:
                return nag

        options = dict(
            p=self.p, dim=1, inplace=self.inplace, to_mean=self.to_mean)

        for i_level in levels:
            level_data = nag[i_level]

            # Levels which do not hold the attribute are skipped
            if getattr(level_data, self.key, None) is None:
                continue

            # Column-wise dropout
            level_data[self.key] = dropout(level_data[self.key], **options)

        return nag


class DropoutRows(Transform):
    """Randomly drop rows of a Data attribute, by calling `dropout`
    along dim=0.

    :param p: float
        Probability of a row to be dropped. No dropout is applied if
        p <= 0
    :param key: str
        Name of the Data attribute whose rows may be dropped
    :param inplace: bool
        Whether the dropout should be performed directly on the input
        or on a copy of it
    :param to_mean: bool
        Whether the dropped values should be set to a mean value
        instead of zero (default). Passed as is to `dropout`
    """

    def __init__(self, p=0.5, key=None, inplace=False, to_mean=False):
        assert key is not None, "A Data key must be specified"
        self.to_mean = to_mean
        self.inplace = inplace
        self.key = key
        self.p = p

    def _process(self, data):
        # Nothing to do if dropout is disabled or if the attribute is
        # absent from the input Data
        if self.p <= 0 or getattr(data, self.key, None) is None:
            return data

        # Row-wise dropout
        options = dict(
            p=self.p, dim=0, inplace=self.inplace, to_mean=self.to_mean)
        data[self.key] = dropout(data[self.key], **options)

        return data


class NAGDropoutRows(Transform):
    """Randomly drop rows of a Data attribute, at the chosen levels of
    a NAG, by calling `dropout` along dim=0.

    :param level: int or str
        Level at which to drop rows. Can be an int or a str. If
        the latter, 'all' will apply on all levels, 'i+' will apply on
        level-i and above, 'i-' will apply on level-i and below
    :param p: float
        Probability of a row to be dropped. No dropout is applied if
        p <= 0
    :param key: str
        Name of the Data attribute whose rows may be dropped
    :param inplace: bool
        Whether the dropout should be performed directly on the input
        or on a copy of it
    :param to_mean: bool
        Whether the dropped values should be set to a mean value
        instead of zero (default). Passed as is to `dropout`
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(
            self, level='all', p=0.5, key=None, inplace=False, to_mean=False):
        assert isinstance(level, int) or level == 'all' or level.endswith('-') \
               or level.endswith('+')
        self.level = level
        self.p = p
        self.key = key
        self.inplace = inplace
        self.to_mean = to_mean

    def _process(self, nag):
        # Nothing to do if dropout is disabled
        if self.p <= 0:
            return nag

        # Convert `level` into the levels to process
        level = self.level
        if isinstance(level, int):
            levels = [level]
        elif level == 'all':
            levels = range(0, nag.num_levels)
        else:
            suffix = level[-1]
            if suffix == '+':
                levels = range(int(level[:-1]), nag.num_levels)
            elif suffix == '-':
                levels = range(0, int(level[:-1]) + 1)
            else:
                return nag

        options = dict(
            p=self.p, dim=0, inplace=self.inplace, to_mean=self.to_mean)

        for i_level in levels:
            level_data = nag[i_level]

            # Levels which do not hold the attribute are skipped
            if getattr(level_data, self.key, None) is None:
                continue

            # Row-wise dropout
            level_data[self.key] = dropout(level_data[self.key], **options)

        return nag


class NAGJitterKey(Transform):
    """Add some gaussian noise to Data['key'] for all data in a NAG.

    :param key: str
        The attribute on which to apply jittering
    :param sigma: float or List(float)
        Standard deviation of the gaussian noise. A list may be passed
        to transform NAG levels with different parameters, in which
        case it is indexed by level. Passing sigma <= 0 will prevent
        any jittering
    :param trunc: float or List(float)
        Truncation bound of the gaussian noise: the noise is drawn from
        a normal distribution truncated to [-trunc, trunc]. A list may
        be passed to transform NAG levels with different parameters, in
        which case it is indexed by level. Passing trunc <= 0 will not
        truncate the normal distribution
    :param strict: bool
        Whether an error should be raised if one of the input NAG levels
        does not have `key` attribute
    """

    _IN_TYPE = NAG
    _OUT_TYPE = NAG

    def __init__(self, key=None, sigma=0.01, trunc=0.05, strict=False):
        assert key is not None, "A key must be specified"
        assert isinstance(sigma, (int, float, list))
        assert isinstance(trunc, (int, float, list))
        self.key = key
        self.sigma = sigma
        self.trunc = trunc
        self.strict = strict

    def _process(self, nag):
        # Broadcast scalar parameters to one value per level
        if not isinstance(self.sigma, list):
            sigma = [self.sigma] * nag.num_levels
        else:
            sigma = self.sigma

        if not isinstance(self.trunc, list):
            trunc = [self.trunc] * nag.num_levels
        else:
            trunc = self.trunc

        for i_level in range(nag.num_levels):

            # Jittering is disabled for this level
            if sigma[i_level] <= 0:
                continue

            # Missing attribute: error in strict mode, skip otherwise
            if getattr(nag[i_level], self.key, None) is None:
                if self.strict:
                    raise ValueError(
                        f"Input data does not have any '{self.key}' attribute")
                else:
                    continue

            if trunc[i_level] > 0:
                # Noise from a normal distribution truncated to
                # [-trunc, trunc]
                noise = torch.nn.init.trunc_normal_(
                    torch.empty_like(nag[i_level][self.key]),
                    mean=0.,
                    std=sigma[i_level],
                    a=-trunc[i_level],
                    b=trunc[i_level])
            else:
                # Noise from a regular normal distribution
                noise = torch.randn_like(
                    nag[i_level][self.key]) * sigma[i_level]

            nag[i_level][self.key] += noise

        return nag