| from transformers import PretrainedConfig | |
class SpatiospatialResNetConfig(PretrainedConfig):
    """Configuration container for the SpatiospatialResNet model.

    Stores the model-specific hyper-parameters as attributes. Every other
    keyword argument is handed to ``PretrainedConfig.__init__`` unchanged,
    so the standard ``transformers`` config options still apply.

    Args:
        mode: Model variant selector, default ``"MC"``. The accepted values
            and what they mean are defined by the model code, not here
            (TODO confirm).
        num_classes: Number of output classes, default ``3``. Presumably
            sizes the classification head; verify against the model.
        is_pretrained: Flag stored for the model, default ``True``.
            Presumably controls loading pretrained backbone weights; confirm.
        dropout: Dropout value, default ``0.3``. Presumably a probability
            applied in the model head; confirm.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    # Identifier transformers records for this config class. It is kept
    # as-is because saved configs may reference it.
    # NOTE(review): "Spatiospatial" may have been meant as
    # "Spatiotemporal"; renaming it would break existing saved configs.
    model_type = "SpatiospatialResNet"

    def __init__(
        self,
        mode="MC",
        num_classes=3,
        is_pretrained=True,
        dropout=0.3,
        **kwargs,
    ):
        # Tuple assignment runs left to right, so the attributes are created
        # in the same order as before (mode, num_classes, is_pretrained,
        # dropout). This preserves the order of the instance __dict__.
        self.mode, self.num_classes, self.is_pretrained, self.dropout = (
            mode,
            num_classes,
            is_pretrained,
            dropout,
        )
        # The base-class initializer runs last, after the custom attributes
        # are set, exactly as in the original.
        super().__init__(**kwargs)