| from transformers import PreTrainedModel | |
| from .GP_UNet import GP_UNet | |
| from .GP_ReconResNet import GP_ReconResNet | |
| from .GP_ShuffleUNet import GP_ShuffleUNet | |
| from .GPModelConfigs import GPUNetConfig, GPReconResNetConfig, GPShuffleUNetConfig | |
class GPUNet(PreTrainedModel):
    """transformers ``PreTrainedModel`` wrapper around ``GP_UNet``.

    The inner network is built from the fields of a ``GPUNetConfig`` and
    kept on ``self.model``; ``forward`` delegates to it unchanged.
    """

    # Config class transformers pairs with this model type.
    config_class = GPUNetConfig

    def __init__(self, config):
        """Build the wrapped ``GP_UNet`` from *config*.

        Args:
            config: A ``GPUNetConfig``. Each attribute named below is passed
                to ``GP_UNet`` as the keyword argument of the same name.
        """
        super().__init__(config)
        # Config attribute names are identical to GP_UNet's keyword names,
        # so they are forwarded one-to-one.
        unet_kwargs = {
            field: getattr(config, field)
            for field in (
                "in_channels",
                "n_classes",
                "depth",
                "wf",
                "padding",
                "batch_norm",
                "up_mode",
                "dropout",
                "Relu",
                "out_act",
            )
        }
        self.model = GP_UNet(**unet_kwargs)

    def forward(self, x):
        """Return the wrapped ``GP_UNet``'s output for input *x*."""
        return self.model(x)
class GPReconResNet(PreTrainedModel):
    """transformers ``PreTrainedModel`` wrapper around ``GP_ReconResNet``.

    The inner network is built from the fields of a ``GPReconResNetConfig``
    and kept on ``self.model``; ``forward`` delegates to it unchanged.
    """

    # Config class transformers pairs with this model type.
    config_class = GPReconResNetConfig

    def __init__(self, config):
        """Build the wrapped ``GP_ReconResNet`` from *config*.

        Args:
            config: A ``GPReconResNetConfig``. Each attribute named below is
                passed to ``GP_ReconResNet`` as the keyword argument of the
                same name.
        """
        super().__init__(config)
        # Config attribute names are identical to GP_ReconResNet's keyword
        # names, so they are forwarded one-to-one.
        resnet_kwargs = {
            field: getattr(config, field)
            for field in (
                "in_channels",
                "n_classes",
                "res_blocks",
                "starting_nfeatures",
                "updown_blocks",
                "is_relu_leaky",
                "do_batchnorm",
                "res_drop_prob",
                "out_act",
                "forwardV",
                "upinterp_algo",
                "post_interp_convtrans",
                "is3D",
            )
        }
        self.model = GP_ReconResNet(**resnet_kwargs)

    def forward(self, x):
        """Return the wrapped ``GP_ReconResNet``'s output for input *x*."""
        return self.model(x)
class GPShuffleUNet(PreTrainedModel):
    """transformers ``PreTrainedModel`` wrapper around ``GP_ShuffleUNet``.

    The inner network is built from the fields of a ``GPShuffleUNetConfig``
    and kept on ``self.model``; ``forward`` delegates to it unchanged.
    """

    # Config class transformers pairs with this model type.
    config_class = GPShuffleUNetConfig

    def __init__(self, config):
        """Build the wrapped ``GP_ShuffleUNet`` from *config*.

        Args:
            config: A ``GPShuffleUNetConfig``. Each attribute named below is
                passed to ``GP_ShuffleUNet`` as the keyword argument of the
                same name.
        """
        super().__init__(config)
        # Config attribute names are identical to GP_ShuffleUNet's keyword
        # names, so they are forwarded one-to-one.
        shuffle_kwargs = {
            field: getattr(config, field)
            for field in (
                "d",
                "in_ch",
                "num_features",
                "n_levels",
                "out_ch",
                "kernel_size",
                "stride",
                "dropout",
                "out_act",
            )
        }
        self.model = GP_ShuffleUNet(**shuffle_kwargs)

    def forward(self, x):
        """Return the wrapped ``GP_ShuffleUNet``'s output for input *x*."""
        return self.model(x)