import torch
import torch.nn as nn
import torch.nn.functional as F


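class Generator(nn.Module):
    """2D generator: transposed convs, then a conv + bilinear upsample, then a final conv.

    gk, gs and gp are the per-layer kernel sizes, strides and paddings; gf lists
    the channel counts and needs at least len(gk) + 1 entries.
    """

    def __init__(self, config, gk, gs, gf, gp):
        super().__init__()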
        self.config = config
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.no_layers = len(gk)
        for lay, (k, s, p) in enumerate(zip(gk, gs, gp)):
            if lay < self.no_layers - 2:
                self.convs.append(
                    nn.ConvTranspose2d(gf[lay], gf[lay + 1], k, s, p, bias=False)
                )
            else:
                self.convs.append(
                    nn.Conv2d(
                        gf[lay],
                        gf[lay + 1],
                        k,
                        s,
                        p,
                        bias=False,
                        padding_mode="reflect",
                    )
                )
            # One BatchNorm per layer; forward() never applies the last one.
            self.bns.append(nn.BatchNorm2d(gf[lay + 1]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for lay, (conv, bn) in enumerate(zip(self.convs[:-1], self.bns[:-1])):
            x = conv(x)
            if lay >= self.no_layers - 2:
                # Conv2d stage: bilinear resize to (2H + 2, 2W + 2) before normalising.
                x = F.interpolate(
                    x,
                    [x.shape[-2] * 2 + 2, x.shape[-1] * 2 + 2],
                    mode="bilinear",
                    align_corners=False,
                )
            x = F.relu_(bn(x))
        if self.config.image_type == "n-phase":
            out = torch.softmax(self.convs[-1](x), dim=1)
        else:
            out = torch.sigmoid(self.convs[-1](x))
        return out  # bs x n x imsize x imsize
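

# Illustrative sketch, not part of the original module: the spatial size the
# Generator above should produce for a square input, following the same layer
# layout as Generator.__init__ and Generator.forward. The helper name
# `generator_output_size` is hypothetical. It assumes integer (square) kernel,
# stride and padding values, and the PyTorch defaults dilation=1 and
# output_padding=0 that the layers above rely on.
def generator_output_size(in_size, gk, gs, gp):
    no_layers = len(gk)
    size = in_size
    for lay, (k, s, p) in enumerate(zip(gk, gs, gp)):
        if lay < no_layers - 2:
            # ConvTranspose2d: (H - 1) * stride - 2 * padding + kernel
            size = (size - 1) * s - 2 * p + k
        else:
            # Conv2d: floor((H + 2 * padding - kernel) / stride) + 1
            # (reflect padding changes the values, not the output size)
            size = (size + 2 * p - k) // s + 1
            if lay == no_layers - 2:
                # forward() resizes this stage to 2H + 2 with F.interpolate
                size = size * 2 + 2
    return size


# Example with made-up parameters (not taken from any real config):
# generator_output_size(4, [4, 4, 3, 3], [2, 2, 1, 1], [1, 1, 1, 0])  # -> 32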


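class Discriminator(nn.Module):
    """Stack of bias-free Conv2d layers with ReLU after every layer except the last.

    dk, ds and dp are the per-layer kernel sizes, strides and paddings; df lists
    the channel counts.
    """

    def __init__(self, dk, ds, dp, df):
        super().__init__()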
        self.convs = nn.ModuleList()
        for lay, (k, s, p) in enumerate(zip(dk, ds, dp)):
            self.convs.append(nn.Conv2d(df[lay], df[lay + 1], k, s, p, bias=False))

    def forward(self, x):
        for conv in self.convs[:-1]:
            x = F.relu_(conv(x))

        # No activation on the last layer; output is bs x df[-1] x H' x W'.
        x = self.convs[-1](x)
        return x


def make_nets(config, training=True):
    """Create the Discriminator and Generator from the network params held by ``config``.

    :param config: a Config class object
    :type config: Config
    :param training: if True, the params on the Config object are used and saved via ``config.save()``; if False, they are first loaded from file via ``config.load()``. Defaults to True
    :type training: bool, optional
    :return: Discriminator and Generator class objects, in that order
    :rtype: Discriminator, Generator
    """

    # save/load params
    if training:
        config.save()
    else:
        config.load()

    dk, ds, df, dp, gk, gs, gf, gp = config.get_net_params()

    # Make nets
    return Discriminator(dk, ds, dp, df), Generator(config, gk, gs, gf, gp)