Spaces:
Running
Running
Update pages/15_CNN.py
Browse files- pages/15_CNN.py +21 -10
pages/15_CNN.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.optim as optim
|
|
@@ -47,20 +47,31 @@ class CNN(nn.Module):
|
|
| 47 |
nn.BatchNorm2d(64),
|
| 48 |
nn.ReLU(),
|
| 49 |
nn.MaxPool2d(2))
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
self.drop = nn.Dropout2d(0.25)
|
| 52 |
self.fc2 = nn.Linear(600, 100)
|
| 53 |
self.fc3 = nn.Linear(100, 10)
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def forward(self, x):
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
return out
|
| 64 |
|
| 65 |
model = CNN().to(device)
|
| 66 |
|
|
|
|
| 1 |
+
]import streamlit as st
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.optim as optim
|
|
|
|
| 47 |
nn.BatchNorm2d(64),
|
| 48 |
nn.ReLU(),
|
| 49 |
nn.MaxPool2d(2))
|
| 50 |
+
|
| 51 |
+
# Automatically determine the size of the flattened features after convolution and pooling
|
| 52 |
+
self._to_linear = None
|
| 53 |
+
self.convs(torch.randn(1, 3, 32, 32))
|
| 54 |
+
|
| 55 |
+
self.fc1 = nn.Linear(self._to_linear, 600)
|
| 56 |
self.drop = nn.Dropout2d(0.25)
|
| 57 |
self.fc2 = nn.Linear(600, 100)
|
| 58 |
self.fc3 = nn.Linear(100, 10)
|
| 59 |
|
| 60 |
+
def convs(self, x):
|
| 61 |
+
x = self.layer1(x)
|
| 62 |
+
x = self.layer2(x)
|
| 63 |
+
if self._to_linear is None:
|
| 64 |
+
self._to_linear = x.view(x.size(0), -1).shape[1]
|
| 65 |
+
return x
|
| 66 |
+
|
| 67 |
def forward(self, x):
|
| 68 |
+
x = self.convs(x)
|
| 69 |
+
x = x.view(x.size(0), -1)
|
| 70 |
+
x = self.fc1(x)
|
| 71 |
+
x = self.drop(x)
|
| 72 |
+
x = self.fc2(x)
|
| 73 |
+
x = self.fc3(x)
|
| 74 |
+
return x
|
|
|
|
| 75 |
|
| 76 |
model = CNN().to(device)
|
| 77 |
|