| from transformers import PretrainedConfig | |
class UnetConfig(PretrainedConfig):
    """Hold the hyperparameters for a U-Net style model.

    The four model-specific settings are stored as plain attributes. Every
    other keyword argument is handed to ``PretrainedConfig.__init__`` so the
    base class can process it.

    Args:
        encoder_name: Name of the encoder to use (default ``"resnet18"``).
            Presumably resolved by the model class that consumes this
            config; verify against the caller.
        num_classes: Number of output classes (default ``16``).
        input_channels: Number of input channels (default ``1``).
        decoder_channels: Per-stage channel widths for the decoder (default
            ``(1024, 512, 256, 128, 64)``). The value is stored exactly as
            passed.
        **kwargs: Extra options forwarded unchanged to
            ``PretrainedConfig.__init__``.

    NOTE(review): no ``model_type`` class attribute is set here. Confirm
    whether this config is meant to be registered with ``AutoConfig``.
    NOTE(review): a tuple ``decoder_channels`` will likely come back as a
    list after a JSON save/load round trip. Confirm that consumers accept
    both.
    """

    def __init__(
        self,
        encoder_name: str = "resnet18",
        num_classes: int = 16,
        input_channels: int = 1,
        decoder_channels: tuple = (1024, 512, 256, 128, 64),
        **kwargs,
    ):
        # Set the model-specific fields first (left-to-right), then let the
        # base class handle the remaining kwargs. This matches the ordering
        # recommended for PretrainedConfig subclasses.
        self.encoder_name, self.num_classes = encoder_name, num_classes
        self.input_channels, self.decoder_channels = input_channels, decoder_channels
        super().__init__(**kwargs)