Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from .genconvit_ed import GenConViTED | |
| from .genconvit_vae import GenConViTVAE | |
| from torchvision import transforms | |
class GenConViT(nn.Module):
    """Wrapper combining the GenConViT ED and VAE sub-models.

    Depending on ``net``, ``forward`` runs the ED model, the VAE model, or
    both (their outputs concatenated along dim 0).

    Args:
        config: Configuration passed unchanged to ``GenConViTED`` and
            ``GenConViTVAE``.
        ed: Base name of the ED checkpoint, loaded from ``weight/{ed}.pth``.
            Pass a falsy value to skip the ED model.
        vae: Base name of the VAE checkpoint, loaded from
            ``weight/{vae}.pth``. Pass a falsy value to skip the VAE model.
        net: Default mode for ``forward``: ``'ed'``, ``'vae'``, or any other
            value (e.g. ``'genconvit'``) to run both sub-models.
        fp16: If true, successfully loaded sub-models are converted with
            ``.half()``.

    Raises:
        FileNotFoundError: If a checkpoint file is missing and ``net``
            requires that sub-model (``'ed'``/``'genconvit'`` for ED,
            ``'vae'``/``'genconvit'`` for VAE).
    """

    def __init__(self, config, ed, vae, net, fp16):
        super(GenConViT, self).__init__()
        self.net = net
        self.fp16 = fp16
        self.model_ed = None
        self.model_vae = None
        if ed:
            self.model_ed, self.checkpoint_ed = self._load_submodel(
                GenConViTED, config, ed, required=net in ('ed', 'genconvit'))
        if vae:
            self.model_vae, self.checkpoint_vae = self._load_submodel(
                GenConViTVAE, config, vae, required=net in ('vae', 'genconvit'))

    def _load_submodel(self, model_cls, config, weight_name, required):
        """Build ``model_cls(config)`` and load ``weight/{weight_name}.pth`` into it.

        Returns ``(model, checkpoint)`` on success, with the model in eval mode
        (and half precision when ``self.fp16`` is set). If the file is missing,
        raises ``FileNotFoundError`` when ``required`` is true; otherwise it
        returns ``(None, None)``. The model is never handed back without its
        weights, so ``forward`` reports it as not loaded instead of running a
        model whose weights were never loaded.
        """
        path = f'weight/{weight_name}.pth'
        try:
            model = model_cls(config)
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(path, map_location=torch.device('cpu'))
        except FileNotFoundError as err:
            if required:
                raise FileNotFoundError(f"Error: {path} file not found.") from err
            return None, None
        # A checkpoint either wraps the weights under 'state_dict' or is the state dict itself.
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        model.load_state_dict(state_dict)
        model.eval()
        if self.fp16:
            model.half()
        return model, checkpoint

    def forward(self, x, net=None):
        """Run the selected sub-model(s) on ``x``.

        Args:
            x: Input passed directly to the sub-model(s).
            net: Optional mode override; defaults to ``self.net``. ``'ed'``
                runs only the ED model, ``'vae'`` only the VAE model, and any
                other value (e.g. ``'genconvit'`` or ``'both'``) runs both.

        Returns:
            The ED output, the first element of the VAE model's 2-tuple
            output, or both concatenated along dim 0 (ED rows first).

        Raises:
            RuntimeError: If a sub-model required by the mode was not loaded.
        """
        if net is None:
            net = self.net
        # NOTE(review): when fp16=True the sub-models are half precision, but x
        # is not cast here; presumably callers pass matching tensors — confirm.
        if net == 'ed':
            if self.model_ed is None:
                raise RuntimeError("ED model (AE) is not loaded. Ensure weights were provided during initialization.")
            return self.model_ed(x)
        if net == 'vae':
            if self.model_vae is None:
                raise RuntimeError("VAE model is not loaded. Ensure weights were provided during initialization.")
            x, _ = self.model_vae(x)
            return x
        # Any other mode ('genconvit', 'both', ...) runs both sub-models.
        if self.model_ed is None or self.model_vae is None:
            raise RuntimeError("Both ED and VAE models must be loaded for 'genconvit' mode.")
        x1 = self.model_ed(x)
        x2, _ = self.model_vae(x)
        # Stack both predictions along the batch dimension rather than averaging
        # them; the caller decides how to combine them.
        return torch.cat((x1, x2), dim=0)