| from transformers import PretrainedConfig | |
class M3DCLIPConfig(PretrainedConfig):
    """Configuration container for the ``m3d_clip`` model.

    Stores the settings for a text model (selected by name/path) and a set of
    image-encoder hyper-parameters, plus two loss flags. Every argument is kept
    as an attribute of the same name. Any extra keyword arguments are forwarded
    to ``PretrainedConfig.__init__``.

    Args:
        language_model_name_or_path: Name or path of the text model
            (default ``'bert-base-uncased'``).
        local_loss: Boolean flag stored as-is; its effect is defined by the
            model code that reads this config.
        gather_loss: Boolean flag stored as-is; its effect is defined by the
            model code that reads this config.
        in_channels: Image-encoder input channel count.
        img_size: Input size, presumably ``(D, H, W)`` given the 3-element
            default and ``spatial_dims=3`` -- TODO confirm. A list (as produced
            by reloading a JSON-saved config) is converted to a tuple.
        patch_size: Patch size with the same layout as ``img_size``; a list
            is converted to a tuple.
        hidden_size, mlp_dim, num_layers, num_heads, pos_embed, dropout_rate,
        spatial_dims: Image-encoder hyper-parameters, stored unchanged.
        max_text_len: Text length setting, stored unchanged.
        vocab_size: Vocabulary size, stored unchanged (default matches the
            ``bert-base-uncased`` default text model).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    # Identifier that transformers writes into the serialized config and uses
    # to match this config class.
    model_type = "m3d_clip"

    def __init__(
        self,
        language_model_name_or_path: str = 'bert-base-uncased',
        local_loss: bool = False,
        gather_loss: bool = True,
        in_channels: int = 1,
        img_size: tuple = (32, 256, 256),
        patch_size: tuple = (4, 16, 16),
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "perceptron",
        dropout_rate: float = 0,
        spatial_dims: int = 3,
        max_text_len: int = 128,
        vocab_size: int = 30522,
        **kwargs,
    ):
        self.language_model_name_or_path = language_model_name_or_path
        self.in_channels = in_channels
        # JSON has no tuple type: a config written by save_pretrained and read
        # back by from_pretrained passes these in as lists. Normalise lists to
        # tuples so the attribute type is stable and a reloaded config compares
        # equal to the one that was saved. Other values are left untouched.
        self.img_size = tuple(img_size) if isinstance(img_size, list) else img_size
        self.patch_size = (
            tuple(patch_size) if isinstance(patch_size, list) else patch_size
        )
        self.hidden_size = hidden_size
        self.mlp_dim = mlp_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.pos_embed = pos_embed
        self.dropout_rate = dropout_rate
        self.spatial_dims = spatial_dims
        self.local_loss = local_loss
        self.gather_loss = gather_loss
        self.max_text_len = max_text_len
        self.vocab_size = vocab_size
        # The model-specific attributes are set before calling the base
        # initializer, which handles the remaining generic config kwargs.
        super().__init__(**kwargs)