File size: 1,564 Bytes
957e2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn

from typing import Iterable

from src.simulation.effect import Effect

################################################################################
# Wrap effects units to apply in sequence
################################################################################


class Simulation(nn.Module):
    """
    Wrapper for sequential application of effects units. Allows for straight-
    through gradient estimation and random effect parameter sampling.

    Effects may be given either as separate positional arguments,
    ``Simulation(a, b, c)``, or as a single iterable, ``Simulation([a, b, c])``.
    Every element must be an ``Effect`` instance. They are stored in an
    ``nn.ModuleList`` and applied in the given order by ``forward``.

    :raises TypeError: if any supplied element is not an ``Effect``
    """
    def __init__(self, *args):
        super().__init__()

        # A lone argument that is itself an Effect is one effect, even if it
        # happens to be iterable. Any other lone iterable is unpacked as the
        # sequence of effects.
        if (
            len(args) == 1
            and isinstance(args[0], Iterable)
            and not isinstance(args[0], Effect)
        ):
            candidates = args[0]
        else:
            candidates = args

        effects = []
        for effect in candidates:
            # Raise instead of assert: asserts are stripped under `python -O`.
            if not isinstance(effect, Effect):
                raise TypeError("Arguments must be Effect objects")
            effects.append(effect)

        self.effects = nn.ModuleList(effects)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply every wrapped effect to ``x``, in order.

        If an effect's ``compute_grad`` attribute is truthy, the effect is
        applied normally and gradients propagate through it. Otherwise a
        straight-through estimator is used. The forward value is the effect's
        output, but on the backward pass the effect acts as the identity, so
        the incoming gradient is passed to ``x`` unchanged.

        :param x: input tensor
        :return: tensor after all effects have been applied
        """
        for effect in self.effects:
            if effect.compute_grad:
                x = effect(x)
                continue

            # The graph through the effect would be discarded by the detach()
            # below anyway, so avoid building it in the first place.
            with torch.no_grad():
                output = effect(x)

            # Value equals `output`; the gradient w.r.t. `x` is the identity.
            x = x + (output - x).detach()

        return x

    def sample_params(self):
        """
        Call ``sample_params()`` on every wrapped effect, in order.

        This presumably draws fresh random parameters for each effect.
        NOTE(review): confirm against the Effect implementation.
        """
        for effect in self.effects:
            effect.sample_params()