Gillie2004 commited on
Commit
6ef34db
·
verified ·
1 Parent(s): 0f1d78b

Upload 3 files

Browse files
src/__pycache__/cnn_model.cpython-310.pyc ADDED
Binary file (1.01 kB). View file
 
src/cat_cnn.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5848645750c42c14efbf3510a02656df894a717bc6d531822ba159e1fcbb37c5
3
+ size 67505560
src/cnn_model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class CatBreedCNN(nn.Module):
4
+ def __init__(self, num_classes):
5
+ super().__init__()
6
+ self.net = nn.Sequential(
7
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
8
+ nn.BatchNorm2d(32),
9
+ nn.ReLU(),
10
+ nn.MaxPool2d(2),
11
+
12
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
13
+ nn.BatchNorm2d(64),
14
+ nn.ReLU(),
15
+ nn.MaxPool2d(2),
16
+
17
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
18
+ nn.BatchNorm2d(128),
19
+ nn.ReLU(),
20
+ nn.MaxPool2d(2),
21
+
22
+ nn.Flatten(),
23
+ nn.Linear(128 * 16 * 16, 512),
24
+ nn.ReLU(),
25
+ nn.Dropout(0.5),
26
+
27
+ nn.Linear(512, num_classes)
28
+ )
29
+
30
+ def forward(self, x):
31
+ return self.net(x)