| from transformers import PreTrainedModel |
| from .unet3d import U_Net, U_Net_DeepSup |
| from .UNetConfigs import UNet3DConfig, UNetMSS3DConfig |
|
|
class UNet3D(PreTrainedModel):
    """``transformers`` ``PreTrainedModel`` wrapper around the project's ``U_Net``.

    The wrapped network is built from ``config.in_ch``, ``config.out_ch`` and
    ``config.init_features``, and ``forward`` delegates straight to it.
    Subclassing ``PreTrainedModel`` with ``config_class`` set lets the model
    be saved and loaded through the transformers API using ``UNet3DConfig``.
    """

    config_class = UNet3DConfig
    # forward() receives its input as ``x``. transformers utilities (Trainer,
    # token/FLOP estimation) look up the model's principal input by this name.
    # Without this override they default to "input_ids" and never find it.
    main_input_name = "x"

    def __init__(self, config):
        """Build the wrapped ``U_Net`` from *config*.

        Args:
            config: A ``UNet3DConfig`` exposing ``in_ch``, ``out_ch`` and
                ``init_features``.
        """
        super().__init__(config)
        self.model = U_Net(
            in_ch=config.in_ch,
            out_ch=config.out_ch,
            init_features=config.init_features,
        )

    def forward(self, x):
        """Run the wrapped ``U_Net`` on *x* and return its output unchanged.

        NOTE(review): *x* is presumably a 5-D volume tensor of shape
        (batch, in_ch, D, H, W). Confirm against ``U_Net`` in ``unet3d``.
        """
        return self.model(x)
| |
class UNetMSS3D(PreTrainedModel):
    """``transformers`` ``PreTrainedModel`` wrapper around the project's ``U_Net_DeepSup``.

    The wrapped network is built from ``config.in_ch``, ``config.out_ch`` and
    ``config.init_features``, and ``forward`` delegates straight to it.
    Subclassing ``PreTrainedModel`` with ``config_class`` set lets the model
    be saved and loaded through the transformers API using ``UNetMSS3DConfig``.

    NOTE(review): ``U_Net_DeepSup`` is presumably a deep-supervision variant
    of ``U_Net``. Confirm its architecture in ``unet3d``.
    """

    config_class = UNetMSS3DConfig
    # forward() receives its input as ``x``. transformers utilities (Trainer,
    # token/FLOP estimation) look up the model's principal input by this name.
    # Without this override they default to "input_ids" and never find it.
    main_input_name = "x"

    def __init__(self, config):
        """Build the wrapped ``U_Net_DeepSup`` from *config*.

        Args:
            config: A ``UNetMSS3DConfig`` exposing ``in_ch``, ``out_ch`` and
                ``init_features``.
        """
        super().__init__(config)
        self.model = U_Net_DeepSup(
            in_ch=config.in_ch,
            out_ch=config.out_ch,
            init_features=config.init_features,
        )

    def forward(self, x):
        """Run the wrapped ``U_Net_DeepSup`` on *x* and return its output unchanged.

        NOTE(review): with deep supervision the return value may contain more
        than one output (e.g. one per decoder scale). Confirm against
        ``U_Net_DeepSup``.
        """
        return self.model(x)