Spaces:
Build error
Build error
Commit ·
10385fa
1
Parent(s): 3bdf51a
Remove unused lines
Browse files- source/model.py +0 -26
source/model.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
| 1 |
import torch.nn as nn
|
| 2 |
import torch.nn.functional as F
|
| 3 |
-
from dataset import get_paths, get_data_loader, Dataset
|
| 4 |
-
from setup import Setup
|
| 5 |
|
| 6 |
|
| 7 |
class CNN(nn.Module):
|
|
@@ -75,27 +73,3 @@ class CNN(nn.Module):
|
|
| 75 |
# print('Out: ', x.size())
|
| 76 |
return F.log_softmax(x, dim=1)
|
| 77 |
|
| 78 |
-
if __name__ == '__main__':
|
| 79 |
-
"""
|
| 80 |
-
Main script to initialize the setup, load datasets, create DataLoader,
|
| 81 |
-
instantiate the CNN model, and display the number of trainable parameters
|
| 82 |
-
and the output size for a batch of images.
|
| 83 |
-
"""
|
| 84 |
-
|
| 85 |
-
setup = Setup()
|
| 86 |
-
|
| 87 |
-
normal_train_paths, red_train_paths, normal_test_paths, red_test_paths = get_paths()
|
| 88 |
-
|
| 89 |
-
train_dataset = Dataset(red_train_paths, normal_train_paths)
|
| 90 |
-
train_loader = get_data_loader(train_dataset, batch_size=setup.BATCH)
|
| 91 |
-
|
| 92 |
-
imgs, labels = next(iter(train_loader))
|
| 93 |
-
|
| 94 |
-
cnn = CNN()
|
| 95 |
-
print(f'Number of trainable parameters in CNN: {sum(p.numel() for p in cnn.parameters() if p.requires_grad)}')
|
| 96 |
-
output = cnn.forward(imgs)
|
| 97 |
-
|
| 98 |
-
# Print info
|
| 99 |
-
print('\nBatch size: ', setup.BATCH)
|
| 100 |
-
print('Images size: ', imgs.size()) # (batch, 3, 32, 32)
|
| 101 |
-
print('CNN output size: ', output.size()) # (batch, 2)
|
|
|
|
| 1 |
import torch.nn as nn
|
| 2 |
import torch.nn.functional as F
|
|
|
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
class CNN(nn.Module):
|
|
|
|
| 73 |
# print('Out: ', x.size())
|
| 74 |
return F.log_softmax(x, dim=1)
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|