from transformers import PretrainedConfig
|
|
class SMSelectiveViTConfig(PretrainedConfig):
    """Configuration for the soft-masked selective Vision Transformer.

    Holds the model's hyperparameters as plain attributes; serialization,
    keyword handling, and the ``model_type`` registry hook are inherited
    from :class:`~transformers.PretrainedConfig`.
    """

    model_type = "softmasked_selective_vit"

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        atten_dim=768,
        depth=12,
        num_heads=12,
        mlp_dim=3072,
        channels=3,
        dropout=0.0,
        drop_path=0.0,
        attention_scale=0.0,
        mask_threshold=0.0,
        patch_drop=0.0,
        use_distil_token=False,
        **kwargs,
    ):
        """Build a config, storing every hyperparameter on the instance.

        Args:
            image_size: Input image side length in pixels.
            patch_size: Side length of each square patch.
            num_classes: Number of classification outputs.
            embed_dim: Patch-embedding dimensionality.
            atten_dim: Dimensionality used by the attention layers.
            depth: Number of transformer blocks.
            num_heads: Attention heads per block.
            mlp_dim: Hidden width of the MLP sub-layer.
            channels: Input image channels.
            dropout: Dropout probability.
            drop_path: Stochastic-depth (drop-path) rate.
            attention_scale: Model-specific attention scaling factor
                (semantics defined by the model code — confirm there).
            mask_threshold: Threshold presumably used by the soft-masking
                mechanism — confirm against the model implementation.
            patch_drop: Fraction of patches to drop (model-specific).
            use_distil_token: Whether a distillation token is used.
            **kwargs: Forwarded unchanged to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)

        # Record each hyperparameter on the instance. setattr() goes through
        # the same attribute machinery as direct `self.x = x` assignment,
        # so behavior is identical to spelling each assignment out.
        hyperparams = dict(
            image_size=image_size,
            patch_size=patch_size,
            num_classes=num_classes,
            embed_dim=embed_dim,
            atten_dim=atten_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            channels=channels,
            dropout=dropout,
            drop_path=drop_path,
            attention_scale=attention_scale,
            mask_threshold=mask_threshold,
            patch_drop=patch_drop,
            use_distil_token=use_distil_token,
        )
        for name, value in hyperparams.items():
            setattr(self, name, value)
|
|
|
|