sahar-yaccov commited on
Commit
bb7018b
verified
1 Parent(s): d3a3d24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -17
app.py CHANGED
@@ -13,28 +13,37 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  # Model Architecture 诪讚讜讬拽转
14
  # -------------------------------
15
  model = nn.Sequential(
16
- nn.Conv2d(3, 16, kernel_size=3, padding=1),
 
17
  nn.ReLU(),
18
- nn.MaxPool2d(2, 2),
19
-
20
- nn.Conv2d(16, 32, kernel_size=3, padding=1),
 
21
  nn.ReLU(),
22
- nn.MaxPool2d(2, 2),
23
-
24
- nn.Conv2d(32, 64, kernel_size=3, padding=1),
 
25
  nn.ReLU(),
26
- nn.MaxPool2d(2, 2),
27
-
28
- nn.Conv2d(64, 128, kernel_size=3, padding=1),
 
29
  nn.ReLU(),
30
- nn.MaxPool2d(2, 2),
31
-
32
- nn.Flatten(),
33
- nn.Dropout(0.6),
34
- nn.Linear(128*14*14, 128), # 诪转讗讬诐 诇-224x224 拽诇讟
 
 
 
 
 
35
  nn.ReLU(),
36
- nn.Dropout(0.4),
37
- nn.Linear(128, 2)
38
  ).to(device)
39
 
40
  model.load_state_dict(torch.load("cnn_model.pth", map_location=device))
 
13
  # Model Architecture 诪讚讜讬拽转
14
  # -------------------------------
15
  model = nn.Sequential(
16
+ nn.Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
17
+ nn.BatchNorm2d(16),
18
  nn.ReLU(),
19
+ nn.MaxPool2d(kernel_size=2, stride=2),
20
+
21
+ nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
22
+ nn.BatchNorm2d(32),
23
  nn.ReLU(),
24
+ nn.MaxPool2d(kernel_size=2, stride=2),
25
+
26
+ nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
27
+ nn.BatchNorm2d(64),
28
  nn.ReLU(),
29
+ nn.MaxPool2d(kernel_size=2, stride=2),
30
+
31
+ nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
32
+ nn.BatchNorm2d(128),
33
  nn.ReLU(),
34
+ nn.MaxPool2d(kernel_size=2, stride=2),
35
+
36
+ nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
37
+ nn.BatchNorm2d(256),
38
+ nn.ReLU(),
39
+ nn.MaxPool2d(kernel_size=2, stride=2),
40
+
41
+ nn.Flatten(start_dim=1, end_dim=-1),
42
+ nn.Dropout(p=0.5),
43
+ nn.Linear(in_features=12544, out_features=256),
44
  nn.ReLU(),
45
+ nn.Dropout(p=0.5),
46
+ nn.Linear(in_features=256, out_features=2)
47
  ).to(device)
48
 
49
  model.load_state_dict(torch.load("cnn_model.pth", map_location=device))