Spaces:
Sleeping
Sleeping
File size: 904 Bytes
a12db03 a3ea780 a12db03 a3ea780 a12db03 a3ea780 a12db03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch.nn as nn
import torch
class CNN(nn.Module):
    """Three-stage convolutional classifier for single-channel 2-D inputs.

    The network is a convolutional feature extractor (``features``) followed
    by a two-layer fully connected head (``classifier``).

    The first linear layer is hard-wired to 2400 input features, so the
    spatial size of the input is effectively fixed.  With the 5x5
    un-padded convolutions and the (4, 2) max-pools used here, an input of
    shape ``(N, 1, 128, 128)`` produces a ``(N, 48, 2, 25)`` feature map,
    i.e. ``48 * 2 * 25 = 2400`` features per sample.  Other input sizes
    whose flattened feature count is not 2400 fail inside the first
    ``nn.Linear`` with a shape-mismatch error.
    """

    def __init__(self, n_classes: int = 50) -> None:
        """Create the feature extractor and the classification head.

        Args:
            n_classes: Number of output units of the final linear layer.
        """
        super().__init__()

        # The position of each layer inside the Sequential containers
        # determines the parameter names in ``state_dict()``
        # (``features.0``, ``features.3``, ``features.6``,
        # ``classifier.1``, ``classifier.4``); keep the order stable so
        # previously saved weights remain loadable.
        self.features = nn.Sequential(
            nn.Conv2d(1, 24, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(4, 2), stride=(4, 2)),
            nn.Conv2d(24, 48, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(4, 2), stride=(4, 2)),
            nn.Conv2d(48, 48, kernel_size=(5, 5)),
            nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            # 2400 = 48 channels * 2 * 25 spatial positions for a
            # 1x128x128 input (see the class docstring).
            nn.Linear(2400, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, n_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the feature extractor, flatten, and apply the classifier.

        Args:
            x: Batched input with one channel, e.g. ``(N, 1, 128, 128)``.

        Returns:
            Tensor of shape ``(N, n_classes)`` holding the raw outputs of
            the last linear layer (no softmax is applied here).
        """
        x = self.features(x)
        # Keep the batch dimension, collapse channels and spatial dims.
        x = x.flatten(1)
        return self.classifier(x)
|