# File size: 2,697 Bytes
# e0c75d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn
from .genconvit_ed import GenConViTED
from .genconvit_vae import GenConViTVAE
from torchvision import transforms

class GenConViT(nn.Module):
    """Wrapper that dispatches inputs to the ED and/or VAE sub-networks.

    Each sub-network is built from ``config`` and initialised from the
    checkpoint ``weight/<name>.pth`` (path relative to the current working
    directory), switched to eval mode and, when ``fp16`` is truthy, cast to
    half precision.

    Attributes:
        net: Default mode used by :meth:`forward` (``'ed'``, ``'vae'``; any
            other value selects the combined ED + VAE path).
        fp16: Whether loaded sub-networks are cast with ``.half()``.
        model_ed: Loaded ``GenConViTED`` instance, or ``None`` when no ED
            weights were requested or the weight file was not found.
        model_vae: Loaded ``GenConViTVAE`` instance, or ``None`` when no VAE
            weights were requested or the weight file was not found.
        checkpoint_ed / checkpoint_vae: Raw object returned by ``torch.load``;
            only set when the corresponding sub-network loaded successfully.
    """

    def __init__(self, config, ed, vae, net, fp16):
        """Build the requested sub-networks and load their weights.

        Args:
            config: Configuration object forwarded to ``GenConViTED`` and
                ``GenConViTVAE``.
            ed: Base name (without ``.pth``) of the ED checkpoint inside
                ``weight/``; a falsy value skips the ED network.
            vae: Base name (without ``.pth``) of the VAE checkpoint inside
                ``weight/``; a falsy value skips the VAE network.
            net: Default mode for :meth:`forward`.
            fp16: If truthy, loaded sub-networks are converted to half
                precision.

        Raises:
            FileNotFoundError: If a weight file is missing and ``net``
                requires that sub-network (``'ed'``/``'genconvit'`` for ED,
                ``'vae'``/``'genconvit'`` for VAE). For any other ``net`` the
                missing sub-network is left as ``None``.
        """
        super().__init__()
        self.net = net
        self.fp16 = fp16
        self.model_ed = None
        self.model_vae = None

        if ed:
            try:
                self.model_ed, self.checkpoint_ed = self._load_submodel(
                    GenConViTED, config, ed
                )
            except FileNotFoundError as err:
                # Best effort: only fatal when the selected mode needs ED.
                # model_ed stays None so forward() reports the problem
                # instead of running an uninitialised network.
                if self.net in ('ed', 'genconvit'):
                    raise FileNotFoundError(
                        f"Error: weight/{ed}.pth file not found."
                    ) from err

        if vae:
            try:
                self.model_vae, self.checkpoint_vae = self._load_submodel(
                    GenConViTVAE, config, vae
                )
            except FileNotFoundError as err:
                # Best effort: only fatal when the selected mode needs VAE.
                if self.net in ('vae', 'genconvit'):
                    raise FileNotFoundError(
                        f"Error: weight/{vae}.pth file not found."
                    ) from err

    def _load_submodel(self, model_cls, config, weight_name):
        """Load ``weight/<weight_name>.pth`` into a new ``model_cls(config)``.

        The checkpoint is read before the network is constructed so that a
        missing file fails before any model is built.

        Returns:
            A ``(model, checkpoint)`` tuple: the network in eval mode (half
            precision when ``self.fp16`` is truthy) and the raw loaded
            checkpoint object.

        Raises:
            FileNotFoundError: Propagated when the weight file does not exist.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(
            f'weight/{weight_name}.pth', map_location=torch.device('cpu')
        )
        model = model_cls(config)

        # Checkpoints may wrap the parameters under a 'state_dict' key or be
        # the state dict itself.
        if 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)

        model.eval()
        if self.fp16:
            model.half()
        return model, checkpoint

    def forward(self, x, net=None):
        """Run ``x`` through the selected sub-network(s).

        Args:
            x: Input passed unchanged to the sub-network(s).
            net: Mode override; ``None`` falls back to ``self.net``.
                ``'ed'`` uses only the ED network, ``'vae'`` only the VAE
                network, and any other value uses both.

        Returns:
            The ED output, the first element of the VAE output, or - in the
            combined mode - both of those concatenated along dimension 0.

        Raises:
            RuntimeError: If a sub-network required by the mode is not loaded.
        """
        if net is None:
            net = self.net

        if net == 'ed':
            if self.model_ed is None:
                raise RuntimeError("ED model (AE) is not loaded. Ensure weights were provided during initialization.")
            return self.model_ed(x)

        if net == 'vae':
            if self.model_vae is None:
                raise RuntimeError("VAE model is not loaded. Ensure weights were provided during initialization.")
            # The VAE returns a pair; only its first element is used here.
            out, _ = self.model_vae(x)
            return out

        # 'genconvit' or 'both'
        if self.model_ed is None or self.model_vae is None:
            raise RuntimeError("Both ED and VAE models must be loaded for 'genconvit' mode.")
        out_ed = self.model_ed(x)
        out_vae, _ = self.model_vae(x)
        # Outputs are stacked along dim 0 rather than averaged; presumably the
        # caller reduces them afterwards - TODO confirm.
        return torch.cat((out_ed, out_vae), dim=0)