"""Hugging Face ``PreTrainedModel`` wrappers around the GP network implementations.

Each wrapper builds its underlying network from the fields of the matching
config class and delegates ``forward`` to it unchanged. Setting
``config_class`` lets ``from_pretrained`` / ``save_pretrained`` pair each
wrapper with the right config type.
"""

from transformers import PreTrainedModel

from .GP_UNet import GP_UNet
from .GP_ReconResNet import GP_ReconResNet
from .GP_ShuffleUNet import GP_ShuffleUNet
from .GPModelConfigs import GPUNetConfig, GPReconResNetConfig, GPShuffleUNetConfig


class GPUNet(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`GP_UNet`.

    The network is stored as the submodule ``model``. Saved state-dict keys
    therefore carry a ``model.`` prefix, so renaming this attribute would
    break existing checkpoints.
    """

    config_class = GPUNetConfig

    def __init__(self, config):
        """Build the wrapped ``GP_UNet`` from ``config``.

        Args:
            config: A ``GPUNetConfig``. Its fields are passed 1:1 as keyword
                arguments to ``GP_UNet``.
        """
        super().__init__(config)
        self.model = GP_UNet(
            in_channels=config.in_channels,
            n_classes=config.n_classes,
            depth=config.depth,
            wf=config.wf,
            padding=config.padding,
            batch_norm=config.batch_norm,
            up_mode=config.up_mode,
            dropout=config.dropout,
            Relu=config.Relu,
            out_act=config.out_act,
        )

    def forward(self, x):
        """Run the wrapped ``GP_UNet`` on ``x`` and return its output unchanged."""
        return self.model(x)


class GPReconResNet(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`GP_ReconResNet`.

    The network is stored as the submodule ``model``. Saved state-dict keys
    therefore carry a ``model.`` prefix, so renaming this attribute would
    break existing checkpoints.
    """

    config_class = GPReconResNetConfig

    def __init__(self, config):
        """Build the wrapped ``GP_ReconResNet`` from ``config``.

        Args:
            config: A ``GPReconResNetConfig``. Its fields are passed 1:1 as
                keyword arguments to ``GP_ReconResNet``.
        """
        super().__init__(config)
        self.model = GP_ReconResNet(
            in_channels=config.in_channels,
            n_classes=config.n_classes,
            res_blocks=config.res_blocks,
            starting_nfeatures=config.starting_nfeatures,
            updown_blocks=config.updown_blocks,
            is_relu_leaky=config.is_relu_leaky,
            do_batchnorm=config.do_batchnorm,
            res_drop_prob=config.res_drop_prob,
            out_act=config.out_act,
            forwardV=config.forwardV,
            upinterp_algo=config.upinterp_algo,
            post_interp_convtrans=config.post_interp_convtrans,
            is3D=config.is3D,
        )

    def forward(self, x):
        """Run the wrapped ``GP_ReconResNet`` on ``x`` and return its output unchanged."""
        return self.model(x)


class GPShuffleUNet(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`GP_ShuffleUNet`.

    The network is stored as the submodule ``model``. Saved state-dict keys
    therefore carry a ``model.`` prefix, so renaming this attribute would
    break existing checkpoints.
    """

    config_class = GPShuffleUNetConfig

    def __init__(self, config):
        """Build the wrapped ``GP_ShuffleUNet`` from ``config``.

        Args:
            config: A ``GPShuffleUNetConfig``. Its fields are passed 1:1 as
                keyword arguments to ``GP_ShuffleUNet``. Note that this network
                uses ``in_ch`` / ``out_ch`` rather than the
                ``in_channels`` / ``n_classes`` names used by the other two wrappers.
        """
        super().__init__(config)
        self.model = GP_ShuffleUNet(
            d=config.d,
            in_ch=config.in_ch,
            num_features=config.num_features,
            n_levels=config.n_levels,
            out_ch=config.out_ch,
            kernel_size=config.kernel_size,
            stride=config.stride,
            dropout=config.dropout,
            out_act=config.out_act,
        )

    def forward(self, x):
        """Run the wrapped ``GP_ShuffleUNet`` on ``x`` and return its output unchanged."""
        return self.model(x)