| from transformers import PretrainedConfig |
| from typing import List |
|
|
class UNet3DConfig(PretrainedConfig):
    """Hugging Face configuration object for the 3D U-Net model.

    Holds three model hyper-parameters as plain attributes. Every other
    keyword argument is handed unchanged to ``PretrainedConfig.__init__``.

    Args:
        in_ch: Stored as ``self.in_ch``; presumably the number of input
            channels. Defaults to 1.
        out_ch: Stored as ``self.out_ch``; presumably the number of output
            channels. Defaults to 1.
        init_features: Stored as ``self.init_features``; presumably the
            feature width of the first stage (TODO confirm against the
            model). Defaults to 64.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "UNet"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        init_features: int = 64,
        **kwargs,
    ) -> None:
        # Model-specific fields are assigned before the base initializer runs,
        # matching the original ordering.
        self.in_ch, self.out_ch, self.init_features = in_ch, out_ch, init_features
        super().__init__(**kwargs)
|
|
class UNetMSS3DConfig(PretrainedConfig):
    """Hugging Face configuration object for the 3D ``UNetMSS`` model variant.

    Carries the same three hyper-parameters as the plain U-Net config and
    registers under ``model_type = "UNetMSS"``. Any extra keyword arguments
    are passed through to ``PretrainedConfig.__init__``.

    Args:
        in_ch: Stored as ``self.in_ch``; presumably the number of input
            channels. Defaults to 1.
        out_ch: Stored as ``self.out_ch``; presumably the number of output
            channels. Defaults to 1.
        init_features: Stored as ``self.init_features``; presumably the
            feature width of the first stage (TODO confirm against the
            model). Defaults to 64.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "UNetMSS"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        init_features: int = 64,
        **kwargs,
    ) -> None:
        # Same ordering as the original: own fields first, then the base class.
        self.in_ch, self.out_ch, self.init_features = in_ch, out_ch, init_features
        super().__init__(**kwargs)