| | from transformers import PretrainedConfig |
| |
|
| | """Spice CNN model configuration""" |
| |
|
# Maps each pretrained SpiceCNN checkpoint identifier to the URL of its hosted
# config.json. The URL path carries the full "<organisation>/<model>" repo id so
# that it resolves to the same repository the key names.
SPICE_CNN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "spicecloud/spice-cnn-base": "https://huggingface.co/spicecloud/spice-cnn-base/resolve/main/config.json",
}
| |
|
| |
|
| | |
class SpiceCNNConfig(PretrainedConfig):
    """
    Configuration container for a [`SpiceCNNModel`].

    An instance stores the hyperparameters listed below as plain attributes; a
    SpiceCNN model is instantiated from them. Constructing the configuration
    without arguments uses the defaults, which are intended to give a
    configuration similar to the
    [spicecloud/spice-cnn-base](https://huggingface.co/spicecloud/spice-cnn-base)
    architecture.

    The class derives from [`PretrainedConfig`], and any additional keyword
    arguments are forwarded to its constructor. Refer to the
    [`PretrainedConfig`] documentation for what those arguments control.

    NOTE(review): this class only stores the values. The parameter meanings
    below are inferred from the names and should be confirmed against the
    model implementation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            Presumably the number of channels of the input images.
        num_classes (`int`, *optional*, defaults to 10):
            Presumably the number of output classes.
        dropout_rate (`float`, *optional*, defaults to 0.4):
            Presumably the dropout probability used by the model.
        hidden_size (`int`, *optional*, defaults to 128):
            Presumably the width of the hidden fully connected layer.
        num_filters (`int`, *optional*, defaults to 16):
            Presumably the number of convolution filters.
        kernel_size (`int`, *optional*, defaults to 3):
            Presumably the convolution kernel size.
        stride (`int`, *optional*, defaults to 1):
            Presumably the convolution stride.
        padding (`int`, *optional*, defaults to 1):
            Presumably the convolution padding.
        pooling_size (`int`, *optional*, defaults to 2):
            Presumably the pooling window size.
        **kwargs:
            Passed through unchanged to `PretrainedConfig.__init__`.
    """

    model_type = "spicecnn"

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        dropout_rate: float = 0.4,
        hidden_size: int = 128,
        num_filters: int = 16,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        pooling_size: int = 2,
        **kwargs,
    ) -> None:
        # The base class consumes the generic keyword arguments first.
        super().__init__(**kwargs)

        # Input / output dimensions.
        self.in_channels = in_channels
        self.num_classes = num_classes

        # Regularisation and hidden-layer settings.
        self.dropout_rate = dropout_rate
        self.hidden_size = hidden_size

        # Convolution and pooling settings.
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.pooling_size = pooling_size
| |
|