| from transformers import PretrainedConfig |
| from typing import List |
|
|
|
|
class AugViTConfig(PretrainedConfig):
    """Configuration for the AugViT model.

    Each hyper-parameter below is stored as an attribute of the same name.
    Any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    The descriptions follow the usual Vision-Transformer meaning of these
    names. NOTE(review): confirm them against the AugViT model
    implementation, which is not visible here.

    Args:
        image_size: Input image side length in pixels.
        patch_size: Patch side length in pixels. NOTE(review): ViT-style
            models usually require ``image_size % patch_size == 0``. This
            config does not check that.
        num_classes: Number of output classes.
        dim: Token embedding / hidden dimension.
        depth: Number of transformer layers.
        heads: Number of attention heads.
        mlp_dim: Hidden size of the feed-forward (MLP) sub-layer.
        dropout: Dropout probability, presumably used inside the
            transformer blocks.
        emb_dropout: Dropout probability, presumably applied to the
            patch/position embeddings.
        num_channels: Number of channels in the input images.
        **kwargs: Passed through unchanged to ``PretrainedConfig``.
    """

    # Identifier that transformers associates with this config type.
    model_type = "augvit"

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 32,
        num_classes: int = 1000,
        dim: int = 128,
        depth: int = 2,
        heads: int = 16,
        mlp_dim: int = 256,
        # These are probabilities: the previous `int` annotation
        # contradicted the 0.1 defaults.
        dropout: float = 0.1,
        emb_dropout: float = 0.1,
        num_channels: int = 3,
        **kwargs,
    ):
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.emb_dropout = emb_dropout
        self.num_channels = num_channels
        # The base class handles the remaining (generic) config kwargs.
        super().__init__(**kwargs)