| from transformers import PretrainedConfig |
|
|
| """Spice CNN model configuration""" |
|
|
# Maps each published checkpoint id to the URL of its hosted `config.json`.
# The URL path must carry the full `<organisation>/<model>` repo id, i.e. the
# same string that is used as the dictionary key.
SPICE_CNN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "spicecloud/spice-cnn-base": (
        "https://huggingface.co/spicecloud/spice-cnn-base/resolve/main/config.json"
    ),
}
|
|
|
|
| |
class SpiceCNNConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`SpiceCNNModel`].
    It is used to instantiate a SpiceCNN model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults
    will yield a similar configuration to that of the SpiceCNN
    [spicecloud/spice-cnn-base](https://huggingface.co/spicecloud/spice-cnn-base)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control
    the model outputs. Read the documentation from [`PretrainedConfig`] for more
    information.

    Every argument below is stored unchanged as an attribute of the same name.
    The descriptions follow the usual CNN meaning of each name; the model code
    that consumes them is not part of this file, so confirm the exact usage
    against `SpiceCNNModel`.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            Presumably the number of channels of the input images.
        num_classes (`int`, *optional*, defaults to 10):
            Presumably the number of output classes of the classifier.
        dropout_rate (`float`, *optional*, defaults to 0.4):
            Presumably the dropout probability used by the model.
        hidden_size (`int`, *optional*, defaults to 128):
            Presumably the width of the hidden fully connected layer.
        num_filters (`int`, *optional*, defaults to 16):
            Presumably the number of convolution filters.
        kernel_size (`int`, *optional*, defaults to 3):
            Presumably the convolution kernel size.
        stride (`int`, *optional*, defaults to 1):
            Presumably the convolution stride.
        padding (`int`, *optional*, defaults to 1):
            Presumably the convolution padding.
        pooling_size (`int`, *optional*, defaults to 2):
            Presumably the pooling window size.
        **kwargs:
            Forwarded as-is to `PretrainedConfig.__init__`.
    """

    # Model-type identifier for this configuration class.
    model_type = "spicecnn"

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        dropout_rate: float = 0.4,
        hidden_size: int = 128,
        num_filters: int = 16,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        pooling_size: int = 2,
        **kwargs,
    ) -> None:
        # The base class consumes the generic configuration options first; the
        # architecture hyper-parameters are then recorded on the instance.
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.pooling_size = pooling_size
|
|