| from transformers import PreTrainedModel | |
| from .w_net_3d import WNet3dUNet, WNet3dAttUNet, WNet3dUNetMSS | |
| from .WNetConfigs import WNet3DConfig, AttWNet3DConfig, WNetMSS3DConfig | |
class WNet3D(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``WNet3dUNet``.

    The inner network is constructed from the ``in_ch``, ``out_ch`` and
    ``init_features`` fields of a ``WNet3DConfig`` and stored on
    ``self.model``; ``forward`` delegates to it unchanged.
    """

    # Tells the transformers loading machinery which config type pairs
    # with this model class.
    config_class = WNet3DConfig

    def __init__(self, config):
        """Build the wrapped ``WNet3dUNet`` from *config*.

        Args:
            config: a ``WNet3DConfig`` exposing ``in_ch``, ``out_ch`` and
                ``init_features``.
        """
        super().__init__(config)
        # Keep the attribute name ``model``: the submodule is registered
        # under it, so renaming it would change saved parameter keys.
        network_kwargs = dict(
            in_ch=config.in_ch,
            out_ch=config.out_ch,
            init_features=config.init_features,
        )
        self.model = WNet3dUNet(**network_kwargs)

    def forward(self, x):
        """Run the wrapped network on *x* and return its output as-is.

        NOTE(review): the expected shape of *x* is defined by ``WNet3dUNet``
        (not visible here) — presumably a batched 3-D volume; confirm.
        """
        result = self.model(x)
        return result
class AttWNet3D(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``WNet3dAttUNet``.

    Reads ``in_ch``, ``out_ch`` and ``init_features`` from an
    ``AttWNet3DConfig``, builds the network once in ``__init__`` and
    exposes it as ``self.model``; ``forward`` is a direct pass-through.
    """

    # Config class that transformers associates with this model.
    config_class = AttWNet3DConfig

    def __init__(self, config):
        """Create the wrapped ``WNet3dAttUNet``.

        Args:
            config: an ``AttWNet3DConfig`` providing ``in_ch``, ``out_ch``
                and ``init_features``.
        """
        super().__init__(config)
        in_channels = config.in_ch
        out_channels = config.out_ch
        base_features = config.init_features
        # The attribute name ``model`` becomes the prefix of the wrapped
        # network's parameter names, so it must stay stable for checkpoints.
        self.model = WNet3dAttUNet(
            in_ch=in_channels,
            out_ch=out_channels,
            init_features=base_features,
        )

    def forward(self, x):
        """Delegate to the wrapped network and return whatever it returns."""
        return self.model.__call__(x)
class WNetMSS3D(PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``WNet3dUNetMSS``.

    The network is sized by the ``in_ch``, ``out_ch`` and ``init_features``
    fields of a ``WNetMSS3DConfig`` and held on ``self.model``.
    """

    # Pairs this model with its config type for transformers' loaders.
    config_class = WNetMSS3DConfig

    def __init__(self, config):
        """Instantiate the wrapped ``WNet3dUNetMSS`` from *config*.

        Args:
            config: a ``WNetMSS3DConfig`` with ``in_ch``, ``out_ch`` and
                ``init_features`` attributes.
        """
        super().__init__(config)
        network_cls = WNet3dUNetMSS
        # Stored as ``model`` so the submodule (and its checkpoint keys)
        # keep the same name.
        self.model = network_cls(
            in_ch=config.in_ch,
            out_ch=config.out_ch,
            init_features=config.init_features,
        )

    def forward(self, x):
        """Forward *x* through the wrapped network, returning its output.

        NOTE(review): the structure of the return value (single tensor vs.
        multiple outputs) is decided by ``WNet3dUNetMSS``, which is not
        visible here — confirm against that class.
        """
        wrapped = self.model
        return wrapped(x)