import torch
import torch.nn as nn
from .genconvit_ed import GenConViTED
from .genconvit_vae import GenConViTVAE
from torchvision import transforms


class GenConViT(nn.Module):
    """Wrapper combining the GenConViT ED and VAE sub-models.

    Depending on ``net``, inference runs through the ED model only
    (``'ed'``), the VAE model only (``'vae'``), or both (any other value,
    e.g. ``'genconvit'``).
    """

    def __init__(self, config, ed, vae, net, fp16):
        """Build the sub-models and load their weights from ``weight/``.

        Args:
            config: Configuration passed unchanged to ``GenConViTED`` and
                ``GenConViTVAE``.
            ed: Base name of the ED checkpoint (``weight/<ed>.pth``); falsy
                to skip the ED model entirely.
            vae: Base name of the VAE checkpoint (``weight/<vae>.pth``);
                falsy to skip the VAE model entirely.
            net: Default mode used by ``forward``: ``'ed'``, ``'vae'`` or
                ``'genconvit'``.
            fp16: If true, each loaded sub-model is converted with ``.half()``.

        Raises:
            FileNotFoundError: if a checkpoint required by ``net`` is missing.
        """
        super().__init__()
        self.net = net
        self.fp16 = fp16
        self.model_ed = None
        self.model_vae = None

        if ed:
            self.model_ed, checkpoint = self._load_submodel(
                GenConViTED, config, ed, ('ed', 'genconvit'))
            if checkpoint is not None:
                # Raw checkpoint kept as an attribute for backward compatibility.
                self.checkpoint_ed = checkpoint

        if vae:
            self.model_vae, checkpoint = self._load_submodel(
                GenConViTVAE, config, vae, ('vae', 'genconvit'))
            if checkpoint is not None:
                # Raw checkpoint kept as an attribute for backward compatibility.
                self.checkpoint_vae = checkpoint

    def _load_submodel(self, model_cls, config, weight_name, required_nets):
        """Instantiate ``model_cls(config)`` and load ``weight/<weight_name>.pth``.

        Args:
            model_cls: Sub-model class to instantiate.
            config: Configuration passed to ``model_cls``.
            weight_name: Checkpoint base name under ``weight/``.
            required_nets: Values of ``self.net`` for which a missing
                checkpoint is a hard error.

        Returns:
            ``(model, checkpoint)`` on success, with the model in eval mode
            (and half precision if ``self.fp16``). ``(None, None)`` when the
            file is missing but not required, so no randomly-initialised
            model is left behind for ``forward`` to use by mistake.

        Raises:
            FileNotFoundError: if the file is missing and ``self.net`` is in
                ``required_nets``.
        """
        model = model_cls(config)
        path = f'weight/{weight_name}.pth'
        try:
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints.
            checkpoint = torch.load(path, map_location=torch.device('cpu'))
        except FileNotFoundError as err:
            if self.net in required_nets:
                raise FileNotFoundError(f"Error: {path} file not found.") from err
            return None, None

        # Checkpoints may either wrap the weights under 'state_dict' or be the
        # state dict itself.
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        model.load_state_dict(state_dict)
        model.eval()
        if self.fp16:
            model.half()
        return model, checkpoint

    def forward(self, x, net=None):
        """Run inference through the sub-model(s) selected by ``net``.

        Args:
            x: Input passed unchanged to the sub-model(s). Its dtype is not
                converted here, so with ``fp16`` the caller presumably must
                supply half-precision input — TODO confirm.
            net: Mode override; defaults to ``self.net``. ``'ed'`` and
                ``'vae'`` select a single sub-model; any other value runs both.

        Returns:
            The selected model's output. In combined mode the ED and VAE
            outputs are concatenated along dim 0 (ED rows first), so the
            result has twice as many rows as either output alone.

        Raises:
            RuntimeError: if a sub-model needed by the mode was not loaded.
        """
        if net is None:
            net = self.net

        if net == 'ed':
            if self.model_ed is None:
                raise RuntimeError("ED model (AE) is not loaded. Ensure weights were provided during initialization.")
            return self.model_ed(x)

        if net == 'vae':
            if self.model_vae is None:
                raise RuntimeError("VAE model is not loaded. Ensure weights were provided during initialization.")
            # The VAE returns a tuple; only its first element is used here.
            out, _ = self.model_vae(x)
            return out

        # 'genconvit' or 'both'
        if self.model_ed is None or self.model_vae is None:
            raise RuntimeError("Both ED and VAE models must be loaded for 'genconvit' mode.")
        x1 = self.model_ed(x)
        x2, _ = self.model_vae(x)
        # Concatenation (not averaging) is deliberate: it preserves both
        # models' outputs, and callers presumably aggregate them.
        return torch.cat((x1, x2), dim=0)