from transformers import PretrainedConfig
from typing import List


class MnistConfig(PretrainedConfig):
    """Configuration for a small convolutional MNIST classifier.

    Holds the model-specific hyper-parameters ``conv1`` and ``conv2``.
    Every other setting (e.g. ``num_labels``, ``id2label``, serialization
    options) is forwarded to :class:`transformers.PretrainedConfig` through
    ``**kwargs``.

    Args:
        conv1: Stored as ``self.conv1``. Presumably the number of output
            channels of the model's first convolution layer — TODO confirm
            against the model class that consumes this config.
        conv2: Stored as ``self.conv2``. Presumably the number of output
            channels of the model's second convolution layer — TODO confirm
            likewise.
        **kwargs: Passed unchanged to ``PretrainedConfig.__init__``.
    """

    # ``model_type`` is an identifier. It does not need to resemble the
    # task. It is written into the saved ``config.json`` and is the key
    # expected by ``AutoConfig.register`` if this class is registered with
    # the auto classes. The existing value is kept so previously saved
    # configs and any registrations continue to match.
    model_type = "MobileNetV1"

    def __init__(
        self,
        conv1: int = 10,
        conv2: int = 20,
        **kwargs,
    ):
        # Custom attributes are assigned first. The base initializer then
        # consumes the standard config kwargs. This follows the
        # Transformers custom-config pattern.
        self.conv1 = conv1
        self.conv2 = conv2
        super().__init__(**kwargs)