| """Configuration classes for RSP models compatible with transformers""" | |
| from transformers import PretrainedConfig | |
class RSPResNetConfig(PretrainedConfig):
    """Configuration for RSP ResNet models.

    Constructor arguments are stored as attributes of the same name; any
    extra keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        block: Name of the residual block type (default ``"Bottleneck"``);
            presumably resolved to a block class by the model code -- confirm.
        layers: Per-stage layer counts (default ``[3, 4, 6, 3]``). ``None``
            selects a fresh copy of the default so that config instances never
            share one list object.
        image_size: Input image size (default 224).
        num_channels: Number of input channels (default 3).
        num_labels: Number of output labels (default 51). Ignored when an
            ``id2label`` mapping is passed in ``kwargs``, because that mapping
            already defines the label set (as when a saved config is reloaded).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "rsp_resnet"

    def __init__(
        self,
        block="Bottleneck",
        layers=None,
        image_size=224,
        num_channels=3,
        num_labels=51,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.block = block
        # Build the default per call: a shared mutable default would leak
        # in-place edits from one config into every config created later.
        self.layers = [3, 4, 6, 3] if layers is None else layers
        self.image_size = image_size
        self.num_channels = num_channels
        # PretrainedConfig derives num_labels from id2label, and assigning
        # num_labels with a different length regenerates generic labels.
        # Only apply our count when the caller did not supply a mapping;
        # otherwise a reloaded config with != 51 labels would be clobbered.
        if kwargs.get("id2label") is None:
            self.num_labels = num_labels
class RSPSwinConfig(PretrainedConfig):
    """Configuration for RSP Swin Transformer models.

    Constructor arguments are stored as attributes of the same name; any
    extra keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        image_size: Input image size (default 224).
        patch_size: Patch size (default 4).
        num_channels: Number of input channels (default 3).
        embed_dim: Embedding dimension (default 96).
        depths: Per-stage depths (default ``[2, 2, 6, 2]``). ``None`` selects
            a fresh copy of the default.
        num_heads: Per-stage attention head counts (default
            ``[3, 6, 12, 24]``). ``None`` selects a fresh copy of the default.
        window_size: Attention window size (default 7).
        mlp_ratio: MLP ratio (default 4.0).
        qkv_bias: Whether QKV projections use a bias (default True).
        ape: Flag stored as-is (default False); presumably enables absolute
            position embeddings as in Swin -- confirm against the model code.
        patch_norm: Flag stored as-is (default True); presumably enables
            normalization after patch embedding -- confirm.
        num_labels: Number of output labels (default 51). Ignored when an
            ``id2label`` mapping is passed in ``kwargs``, because that mapping
            already defines the label set (as when a saved config is reloaded).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "rsp_swin"

    def __init__(
        self,
        image_size=224,
        patch_size=4,
        num_channels=3,
        embed_dim=96,
        depths=None,
        num_heads=None,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        ape=False,
        patch_norm=True,
        num_labels=51,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        # Fresh lists per call: shared mutable defaults would leak in-place
        # edits from one config into every config created later.
        self.depths = [2, 2, 6, 2] if depths is None else depths
        self.num_heads = [3, 6, 12, 24] if num_heads is None else num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.ape = ape
        self.patch_norm = patch_norm
        # PretrainedConfig derives num_labels from id2label, and assigning
        # num_labels with a different length regenerates generic labels.
        # Only apply our count when the caller did not supply a mapping;
        # otherwise a reloaded config with != 51 labels would be clobbered.
        if kwargs.get("id2label") is None:
            self.num_labels = num_labels
class RSPViTAEConfig(PretrainedConfig):
    """Configuration for RSP ViTAE models.

    Constructor arguments are stored as attributes of the same name; any
    extra keyword arguments are forwarded to ``PretrainedConfig.__init__``.
    Every list-valued argument defaults to ``None``, which selects a fresh
    copy of the documented default list.

    Args:
        image_size: Input image size (default 224).
        num_channels: Number of input channels (default 3).
        stages: Number of stages (default 4). Not checked against the lengths
            of the per-stage lists below.
        embed_dims: Per-stage embedding dims (default ``[64, 64, 128, 256]``).
        token_dims: Per-stage token dims (default ``[64, 128, 256, 512]``).
        downsample_ratios: Per-stage downsample ratios (default
            ``[4, 2, 2, 2]``).
        NC_depth: Per-stage depths for the "NC" cells (default
            ``[2, 2, 8, 2]``); presumably ViTAE normal cells -- confirm.
        NC_heads: Per-stage head counts for "NC" cells (default
            ``[1, 2, 4, 8]``).
        RC_heads: Per-stage head counts for "RC" cells (default
            ``[1, 1, 2, 4]``); presumably ViTAE reduction cells -- confirm.
        NC_group: Per-stage group values for "NC" cells (default
            ``[1, 32, 64, 128]``).
        RC_group: Per-stage group values for "RC" cells (default
            ``[1, 16, 32, 64]``).
        mlp_ratio: MLP ratio (default 4.0).
        num_labels: Number of output labels (default 51). Ignored when an
            ``id2label`` mapping is passed in ``kwargs``, because that mapping
            already defines the label set (as when a saved config is reloaded).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "rsp_vitae"

    def __init__(
        self,
        image_size=224,
        num_channels=3,
        stages=4,
        embed_dims=None,
        token_dims=None,
        downsample_ratios=None,
        NC_depth=None,
        NC_heads=None,
        RC_heads=None,
        NC_group=None,
        RC_group=None,
        mlp_ratio=4.0,
        num_labels=51,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.num_channels = num_channels
        self.stages = stages
        # Fresh lists per call: shared mutable defaults would leak in-place
        # edits from one config into every config created later.
        self.embed_dims = [64, 64, 128, 256] if embed_dims is None else embed_dims
        self.token_dims = [64, 128, 256, 512] if token_dims is None else token_dims
        self.downsample_ratios = (
            [4, 2, 2, 2] if downsample_ratios is None else downsample_ratios
        )
        self.NC_depth = [2, 2, 8, 2] if NC_depth is None else NC_depth
        self.NC_heads = [1, 2, 4, 8] if NC_heads is None else NC_heads
        self.RC_heads = [1, 1, 2, 4] if RC_heads is None else RC_heads
        self.NC_group = [1, 32, 64, 128] if NC_group is None else NC_group
        self.RC_group = [1, 16, 32, 64] if RC_group is None else RC_group
        self.mlp_ratio = mlp_ratio
        # PretrainedConfig derives num_labels from id2label, and assigning
        # num_labels with a different length regenerates generic labels.
        # Only apply our count when the caller did not supply a mapping;
        # otherwise a reloaded config with != 51 labels would be clobbered.
        if kwargs.get("id2label") is None:
            self.num_labels = num_labels