Update train.py
Browse files
train.py
CHANGED
|
@@ -67,13 +67,12 @@ class InteractorImageClassification(Interactor):
|
|
| 67 |
in_channels=3,
|
| 68 |
num_classes=10,
|
| 69 |
d_model = 256,
|
| 70 |
-
num_tokens = 64,
|
| 71 |
num_layers=4,
|
| 72 |
|
| 73 |
|
| 74 |
):
|
| 75 |
num_patches = check_sizes(image_size, patch_size)
|
| 76 |
-
super().__init__(d_model,
|
| 77 |
self.patcher = nn.Conv2d(
|
| 78 |
in_channels, d_model, kernel_size=patch_size, stride=patch_size
|
| 79 |
)
|
|
|
|
| 67 |
in_channels=3,
|
| 68 |
num_classes=10,
|
| 69 |
d_model = 256,
|
|
|
|
| 70 |
num_layers=4,
|
| 71 |
|
| 72 |
|
| 73 |
):
|
| 74 |
num_patches = check_sizes(image_size, patch_size)
|
| 75 |
+
super().__init__(d_model, num_layers)
|
| 76 |
self.patcher = nn.Conv2d(
|
| 77 |
in_channels, d_model, kernel_size=patch_size, stride=patch_size
|
| 78 |
)
|