File size: 773 Bytes
0f5deb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from .reconstruction_model import Reconstruction3DEncoder, Reconstruction3DDecoder

class convAE(torch.nn.Module):
    """Convolutional autoencoder used for input reconstruction.

    Wraps a ``Reconstruction3DEncoder`` / ``Reconstruction3DDecoder`` pair;
    ``forward`` encodes the input and decodes the resulting features.

    Args:
        chnum_in: Number of input channels, passed to both the encoder and the
            decoder. Defaults to 3 (RGB); use 1 for grayscale input.
    """

    def __init__(self, chnum_in: int = 3):  # for reconstruction
        super().__init__()

        # NOTE(review): not read anywhere in this class; presumably inspected
        # by calling code to detect reconstruction mode — confirm.
        self.reconstruction = True

        self.encoder = Reconstruction3DEncoder(chnum_in=chnum_in)
        self.decoder = Reconstruction3DDecoder(chnum_in=chnum_in)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return the decoder's output.

        Args:
            x: Input batch in whatever layout ``Reconstruction3DEncoder``
                accepts (presumably ``(N, C, T, H, W)`` with
                ``C == chnum_in`` — TODO confirm).

        Returns:
            The decoder output for the encoded features.
        """
        fea = self.encoder(x)
        # The decoder is handed a copy of the features. Kept from the original
        # implementation; presumably guards ``fea`` against in-place ops inside
        # the decoder — only drop it after confirming the decoder has none.
        output = self.decoder(fea.clone())
        return output