| from transformers import PretrainedConfig
|
| from typing import Literal
|
|
|
|
|
class BirdMAEConfig(PretrainedConfig):
    """Configuration for the Bird-MAE-Base model from the original paper.

    Stores the ViT-style encoder hyperparameters for a masked autoencoder
    operating on (by default) 512x128 single-channel spectrogram inputs.

    Args:
        img_size_x: Input width in pixels (time axis for spectrograms).
        img_size_y: Input height in pixels (frequency axis).
        patch_size: Side length of the square patches the image is split into.
        in_chans: Number of input channels (1 for mono spectrograms).
        embed_dim: Transformer embedding dimension.
        depth: Number of transformer encoder blocks.
        num_heads: Number of attention heads per block.
        mlp_ratio: Hidden-dim multiplier for the MLP inside each block.
        pos_trainable: Whether positional embeddings are learnable.
        qkv_bias: Whether the QKV projections use a bias term.
        qk_norm: Whether to apply normalization to queries and keys.
        init_values: LayerScale init value; ``None`` disables LayerScale.
        drop_rate: Dropout probability, applied uniformly to all dropout
            sites (see note in the body).
        norm_layer_eps: Epsilon for the LayerNorm layers.
        global_pool: Pooling strategy for the final representation:
            ``"cls"`` (CLS token), ``"mean"`` (mean over patches), or
            ``None`` (no pooling).
        **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
    """

    _auto_class = "AutoConfig"

    def __init__(
        self,
        img_size_x: int = 512,
        img_size_y: int = 128,
        patch_size: int = 16,
        in_chans: int = 1,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: int = 4,
        pos_trainable: bool = False,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: float | None = None,
        drop_rate: float = 0.0,
        norm_layer_eps: float = 1e-6,
        global_pool: Literal["cls", "mean"] | None = "mean",
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.img_size_x = img_size_x
        self.img_size_y = img_size_y
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.pos_trainable = pos_trainable

        self.qkv_bias = qkv_bias
        self.qk_norm = qk_norm
        self.init_values = init_values
        self.drop_rate = drop_rate
        # A single drop_rate argument intentionally drives every dropout
        # site (positional, attention, drop-path, projection) — the config
        # does not expose them independently.
        self.pos_drop_rate = drop_rate
        self.attn_drop_rate = drop_rate
        self.drop_path_rate = drop_rate
        self.proj_drop_rate = drop_rate
        self.norm_layer_eps = norm_layer_eps
        self.global_pool = global_pool

        # Derived patch-grid geometry. Integer division silently truncates,
        # so image sizes are assumed to be multiples of patch_size
        # (512/16 = 32, 128/16 = 8 for the defaults).
        self.num_patches_x = img_size_x // patch_size
        self.num_patches_y = img_size_y // patch_size
        self.num_patches = self.num_patches_x * self.num_patches_y
        # +1 accounts for the CLS token prepended to the patch sequence.
        self.num_tokens = self.num_patches + 1