# SM-Selective-ViT-Base-224 / configuration_selectivevit.py
# Uploaded by XAFT ("Upload model files", commit e01e400).
# Configuration class for the soft-masked Selective ViT model.
from transformers import PretrainedConfig
class SMSelectiveViTConfig(PretrainedConfig):
    """Hugging Face configuration for the soft-masked Selective ViT.

    Every constructor argument is copied onto the instance under the same
    name. ``PretrainedConfig`` serializes instance attributes, so these values
    are written to, and restored from, ``config.json``. No validation of the
    values is performed here.

    NOTE(review): the per-argument descriptions below are inferred from the
    parameter names only; confirm each against the modeling code.

    Args:
        image_size: Input image resolution (default 224).
        patch_size: Patch edge length for the patch embedding (default 16).
        num_classes: Number of classifier outputs (default 1000).
        embed_dim: Token embedding width (default 768).
        atten_dim: Attention projection width (default 768).
        depth: Number of transformer blocks (default 12).
        num_heads: Attention heads per block (default 12).
        mlp_dim: Hidden width of the MLP sub-layer (default 3072).
        channels: Number of input image channels (default 3).
        dropout: Dropout probability (default 0.0).
        drop_path: Presumably the stochastic-depth rate (default 0.0).
        attention_scale: Stored as-is; its meaning is not visible in this
            file (default 0.0).
        mask_threshold: Presumably the soft-mask threshold (default 0.0).
        patch_drop: Presumably a patch-dropping rate (default 0.0).
        use_distil_token: Presumably toggles a distillation token
            (default False).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Type identifier HF writes into config.json for this configuration class.
    model_type = "softmasked_selective_vit"

    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        atten_dim=768,
        depth=12,
        num_heads=12,
        mlp_dim=3072,
        channels=3,
        dropout=0.0,
        drop_path=0.0,
        attention_scale=0.0,
        mask_threshold=0.0,
        patch_drop=0.0,
        use_distil_token=False,
        **kwargs,
    ):
        # The base class consumes its generic options first; the
        # model-specific hyperparameters are attached afterwards.
        super().__init__(**kwargs)

        # Input / patch geometry.
        self.image_size, self.patch_size = image_size, patch_size
        # Classification head.
        self.num_classes = num_classes
        # Transformer widths.
        self.embed_dim, self.atten_dim = embed_dim, atten_dim
        # Transformer stack shape.
        self.depth, self.num_heads, self.mlp_dim = depth, num_heads, mlp_dim
        # Input channels.
        self.channels = channels
        # Regularization.
        self.dropout, self.drop_path = dropout, drop_path
        # Selective / soft-masking controls.
        self.attention_scale, self.mask_threshold, self.patch_drop = (
            attention_scale,
            mask_threshold,
            patch_drop,
        )
        # Optional extra token.
        self.use_distil_token = use_distil_token