# Source: Hugging Face Hub upload by soumickmj
# Commit 0a4a687 (verified): "Upload GPReconResNet"
from transformers import PreTrainedModel
from .GP_UNet import GP_UNet
from .GP_ReconResNet import GP_ReconResNet
from .GP_ShuffleUNet import GP_ShuffleUNet
from .GPModelConfigs import GPUNetConfig, GPReconResNetConfig, GPShuffleUNetConfig
class GPUNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``GP_UNet``.

    Every constructor argument of the wrapped network is taken from the
    attribute of the same name on a ``GPUNetConfig`` instance.
    """

    config_class = GPUNetConfig

    # GP_UNet keyword arguments; each value is copied verbatim from the
    # config attribute with the same name, in this order.
    _gp_arg_names = (
        "in_channels",
        "n_classes",
        "depth",
        "wf",
        "padding",
        "batch_norm",
        "up_mode",
        "dropout",
        "Relu",
        "out_act",
    )

    def __init__(self, config):
        super().__init__(config)
        gp_kwargs = {name: getattr(config, name) for name in self._gp_arg_names}
        self.model = GP_UNet(**gp_kwargs)

    def forward(self, x):
        """Run the wrapped ``GP_UNet`` on ``x`` and return its output as-is."""
        network = self.model
        return network(x)
class GPReconResNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``GP_ReconResNet``.

    Every constructor argument of the wrapped network is taken from the
    attribute of the same name on a ``GPReconResNetConfig`` instance.
    """

    config_class = GPReconResNetConfig

    # GP_ReconResNet keyword arguments; each value is copied verbatim from
    # the config attribute with the same name, in this order.
    _gp_arg_names = (
        "in_channels",
        "n_classes",
        "res_blocks",
        "starting_nfeatures",
        "updown_blocks",
        "is_relu_leaky",
        "do_batchnorm",
        "res_drop_prob",
        "out_act",
        "forwardV",
        "upinterp_algo",
        "post_interp_convtrans",
        "is3D",
    )

    def __init__(self, config):
        super().__init__(config)
        gp_kwargs = {name: getattr(config, name) for name in self._gp_arg_names}
        self.model = GP_ReconResNet(**gp_kwargs)

    def forward(self, x):
        """Run the wrapped ``GP_ReconResNet`` on ``x`` and return its output as-is."""
        network = self.model
        return network(x)
class GPShuffleUNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``GP_ShuffleUNet``.

    Every constructor argument of the wrapped network is taken from the
    attribute of the same name on a ``GPShuffleUNetConfig`` instance.
    """

    config_class = GPShuffleUNetConfig

    # GP_ShuffleUNet keyword arguments; each value is copied verbatim from
    # the config attribute with the same name, in this order.
    _gp_arg_names = (
        "d",
        "in_ch",
        "num_features",
        "n_levels",
        "out_ch",
        "kernel_size",
        "stride",
        "dropout",
        "out_act",
    )

    def __init__(self, config):
        super().__init__(config)
        gp_kwargs = {name: getattr(config, name) for name in self._gp_arg_names}
        self.model = GP_ShuffleUNet(**gp_kwargs)

    def forward(self, x):
        """Run the wrapped ``GP_ShuffleUNet`` on ``x`` and return its output as-is."""
        network = self.model
        return network(x)