File size: 2,239 Bytes
0a4a687 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from transformers import PreTrainedModel
from .GP_UNet import GP_UNet
from .GP_ReconResNet import GP_ReconResNet
from .GP_ShuffleUNet import GP_ShuffleUNet
from .GPModelConfigs import GPUNetConfig, GPReconResNetConfig, GPShuffleUNetConfig
class GPUNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``GP_UNet``.

    Each constructor argument of the wrapped network is taken from the
    attribute of the same name on the config object. The resulting network
    is stored as ``self.model``.
    """

    config_class = GPUNetConfig

    def __init__(self, config):
        """Initialise the base model with ``config`` and build ``GP_UNet`` from it.

        Args:
            config: Configuration object (``config_class`` declares
                ``GPUNetConfig``). It must expose ``in_channels``, ``n_classes``,
                ``depth``, ``wf``, ``padding``, ``batch_norm``, ``up_mode``,
                ``dropout``, ``Relu`` and ``out_act``.
        """
        super().__init__(config)
        # GP_UNet's keyword names coincide with the config attribute names, so
        # the arguments are forwarded by name (in this order) via getattr.
        arg_names = (
            "in_channels",
            "n_classes",
            "depth",
            "wf",
            "padding",
            "batch_norm",
            "up_mode",
            "dropout",
            "Relu",
            "out_act",
        )
        self.model = GP_UNet(**{name: getattr(config, name) for name in arg_names})

    def forward(self, x):
        """Pass ``x`` to the wrapped ``GP_UNet`` and return its output as-is."""
        output = self.model(x)
        return output
class GPReconResNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``GP_ReconResNet``.

    Each constructor argument of the wrapped network is taken from the
    attribute of the same name on the config object. The resulting network
    is stored as ``self.model``.
    """

    config_class = GPReconResNetConfig

    def __init__(self, config):
        """Initialise the base model with ``config`` and build ``GP_ReconResNet``.

        Args:
            config: Configuration object (``config_class`` declares
                ``GPReconResNetConfig``). It must expose ``in_channels``,
                ``n_classes``, ``res_blocks``, ``starting_nfeatures``,
                ``updown_blocks``, ``is_relu_leaky``, ``do_batchnorm``,
                ``res_drop_prob``, ``out_act``, ``forwardV``, ``upinterp_algo``,
                ``post_interp_convtrans`` and ``is3D``.
        """
        super().__init__(config)
        # GP_ReconResNet's keyword names coincide with the config attribute
        # names, so the arguments are forwarded by name (in this order).
        arg_names = (
            "in_channels",
            "n_classes",
            "res_blocks",
            "starting_nfeatures",
            "updown_blocks",
            "is_relu_leaky",
            "do_batchnorm",
            "res_drop_prob",
            "out_act",
            "forwardV",
            "upinterp_algo",
            "post_interp_convtrans",
            "is3D",
        )
        self.model = GP_ReconResNet(
            **{name: getattr(config, name) for name in arg_names}
        )

    def forward(self, x):
        """Pass ``x`` to the wrapped ``GP_ReconResNet`` and return its output as-is."""
        output = self.model(x)
        return output
class GPShuffleUNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``GP_ShuffleUNet``.

    The wrapped network is built from fields of the config and stored as
    ``self.model``. ``forward`` simply delegates to it.
    """

    config_class = GPShuffleUNetConfig

    def __init__(self, config):
        """Initialise the base model with ``config`` and build ``GP_ShuffleUNet``.

        Args:
            config: Configuration object (``config_class`` declares
                ``GPShuffleUNetConfig``). It must expose ``d``, ``in_ch``,
                ``num_features``, ``n_levels``, ``out_ch``, ``kernel_size``,
                ``stride``, ``dropout`` and ``out_act``.
        """
        super().__init__(config)
        self.model = GP_ShuffleUNet(
            d=config.d,
            in_ch=config.in_ch,
            num_features=config.num_features,
            n_levels=config.n_levels,
            out_ch=config.out_ch,
            kernel_size=config.kernel_size,
            stride=config.stride,
            dropout=config.dropout,
            out_act=config.out_act,
        )

    def forward(self, x):
        """Pass ``x`` to the wrapped ``GP_ShuffleUNet`` and return its output as-is."""
        # The original line ended in a stray " |" artifact, which is a syntax
        # error; the delegation itself is unchanged.
        return self.model(x)