File size: 1,580 Bytes
957e2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn

from typing import Iterable

from src.simulation.component import Component

################################################################################
# Wrap preprocessing stages and apply sequentially
################################################################################


class Preprocessor(nn.Module):
    """
    Apply a sequence of preprocessing stages in order.

    Every stage must be a ``Component``. Stages whose ``compute_grad``
    attribute is falsy are applied with a straight-through estimator: the
    forward pass uses the stage's output, while the backward pass treats the
    stage as the identity so gradients still reach the input. Because random
    parameter sampling is not required, stages only need to be ``Component``
    objects.
    """

    def __init__(self, *args):
        """
        Build the pipeline from the given stages.

        :param args: either the stages passed individually
                     (``Preprocessor(a, b)``) or a single iterable of stages
                     (``Preprocessor([a, b])``); order is application order
        :raises TypeError: if any stage is not a ``Component``
        """
        super().__init__()

        # a single iterable argument is unpacked; otherwise the positional
        # arguments themselves are the stages
        if len(args) == 1 and isinstance(args[0], Iterable):
            candidates = args[0]
        else:
            candidates = args

        stages = []
        for stage in candidates:
            # raise rather than assert: assertions are stripped under
            # `python -O`, which would silently admit invalid stages
            if not isinstance(stage, Component):
                raise TypeError(
                    "Arguments must be Component objects, got "
                    f"{type(stage).__name__}"
                )
            stages.append(stage)

        self.stages = nn.ModuleList(stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply every stage to ``x`` in sequence.

        :param x: input tensor passed to the first stage
        :return: output of the last stage (``x`` itself if there are no
                 stages)
        """
        # apply in sequence
        for stage in self.stages:

            if stage.compute_grad:
                x = stage(x)

            else:
                # the stage's output is detached below, so there is no need
                # to record an autograd graph for it
                with torch.no_grad():
                    output = stage(x)

                # straight-through estimator: forward value is `output`,
                # backward gradient w.r.t. `x` is the identity.
                # NOTE(review): relies on `output` having the same shape as
                # `x` (otherwise `output - x` broadcasts) — confirm for all
                # non-differentiable stages
                x = x + (output - x).detach()

        return x