Upload model
Browse files- config.json +1 -0
- configuration_spice_cnn.py +2 -0
- modeling_spice_cnn.py +1 -1
- pytorch_model.bin +1 -1
config.json
CHANGED
|
@@ -8,6 +8,7 @@
|
|
| 8 |
},
|
| 9 |
"dropout_rate": 0.2,
|
| 10 |
"hidden_size": 128,
|
|
|
|
| 11 |
"kernel_size": 3,
|
| 12 |
"model_type": "spicecnn",
|
| 13 |
"num_classes": 10,
|
|
|
|
| 8 |
},
|
| 9 |
"dropout_rate": 0.2,
|
| 10 |
"hidden_size": 128,
|
| 11 |
+
"in_channels": 1,
|
| 12 |
"kernel_size": 3,
|
| 13 |
"model_type": "spicecnn",
|
| 14 |
"num_classes": 10,
|
configuration_spice_cnn.py
CHANGED
|
@@ -26,6 +26,7 @@ class SpiceCNNConfig(PretrainedConfig):
|
|
| 26 |
|
| 27 |
def __init__(
|
| 28 |
self,
|
|
|
|
| 29 |
num_classes: int = 10,
|
| 30 |
dropout_rate: float = 0.2,
|
| 31 |
hidden_size: int = 128,
|
|
@@ -37,6 +38,7 @@ class SpiceCNNConfig(PretrainedConfig):
|
|
| 37 |
**kwargs
|
| 38 |
):
|
| 39 |
super().__init__(**kwargs)
|
|
|
|
| 40 |
self.num_classes = num_classes
|
| 41 |
self.dropout_rate = dropout_rate
|
| 42 |
self.hidden_size = hidden_size
|
|
|
|
| 26 |
|
| 27 |
def __init__(
|
| 28 |
self,
|
| 29 |
+
in_channels: int = 3,
|
| 30 |
num_classes: int = 10,
|
| 31 |
dropout_rate: float = 0.2,
|
| 32 |
hidden_size: int = 128,
|
|
|
|
| 38 |
**kwargs
|
| 39 |
):
|
| 40 |
super().__init__(**kwargs)
|
| 41 |
+
self.in_channels = in_channels
|
| 42 |
self.num_classes = num_classes
|
| 43 |
self.dropout_rate = dropout_rate
|
| 44 |
self.hidden_size = hidden_size
|
modeling_spice_cnn.py
CHANGED
|
@@ -12,7 +12,7 @@ class SpiceCNNModelForImageClassification(PreTrainedModel):
|
|
| 12 |
super().__init__(config)
|
| 13 |
layers = [
|
| 14 |
nn.Conv2d(
|
| 15 |
-
|
| 16 |
16,
|
| 17 |
kernel_size=config.kernel_size,
|
| 18 |
stride=config.stride,
|
|
|
|
| 12 |
super().__init__(config)
|
| 13 |
layers = [
|
| 14 |
nn.Conv2d(
|
| 15 |
+
config.in_channels,
|
| 16 |
16,
|
| 17 |
kernel_size=config.kernel_size,
|
| 18 |
stride=config.stride,
|
pytorch_model.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 830347
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2bd33ae4006b549f8ef4839e107525981d190d4922a7236e5de3a59190450a1
|
| 3 |
size 830347
|