harry
commited on
Commit
·
5144b79
1
Parent(s):
1fff313
feat: update model training and logging, add mypy cache to .gitignore
Browse files- .gitignore +1 -0
- mnist_classifier/model.py +2 -2
- mnist_classifier/train.py +18 -2
- models/mnist_model_lr0.001_bs64_ep10.pth +3 -0
- torchvision.pyi +9 -0
.gitignore
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
__pycache__/
|
| 2 |
*.pyc
|
| 3 |
.pytest_cache/
|
|
|
|
| 4 |
wandb/
|
| 5 |
checkpoints/
|
| 6 |
*.egg-info/
|
|
|
|
| 1 |
__pycache__/
|
| 2 |
*.pyc
|
| 3 |
.pytest_cache/
|
| 4 |
+
.mypy_cache/
|
| 5 |
wandb/
|
| 6 |
checkpoints/
|
| 7 |
*.egg-info/
|
mnist_classifier/model.py
CHANGED
|
@@ -7,8 +7,8 @@ class MNISTModel(nn.Module):
|
|
| 7 |
super().__init__()
|
| 8 |
self.conv1 = nn.Conv2d(1, 32, 3, 1)
|
| 9 |
self.conv2 = nn.Conv2d(32, 64, 3, 1)
|
| 10 |
-
self.dropout1 = nn.
|
| 11 |
-
self.dropout2 = nn.
|
| 12 |
self.fc1 = nn.Linear(9216, 128)
|
| 13 |
self.fc2 = nn.Linear(128, 10)
|
| 14 |
|
|
|
|
| 7 |
super().__init__()
|
| 8 |
self.conv1 = nn.Conv2d(1, 32, 3, 1)
|
| 9 |
self.conv2 = nn.Conv2d(32, 64, 3, 1)
|
| 10 |
+
self.dropout1 = nn.Dropout(0.25)
|
| 11 |
+
self.dropout2 = nn.Dropout(0.5)
|
| 12 |
self.fc1 = nn.Linear(9216, 128)
|
| 13 |
self.fc2 = nn.Linear(128, 10)
|
| 14 |
|
mnist_classifier/train.py
CHANGED
|
@@ -5,6 +5,8 @@ from torch.utils.data import DataLoader
|
|
| 5 |
from torch.utils.tensorboard.writer import SummaryWriter
|
| 6 |
from mnist_classifier.dataset import MNISTDataModule
|
| 7 |
from mnist_classifier.model import MNISTModel
|
|
|
|
|
|
|
| 8 |
|
| 9 |
def train():
|
| 10 |
# Set device
|
|
@@ -12,7 +14,8 @@ def train():
|
|
| 12 |
print(f"Using device: {device}")
|
| 13 |
|
| 14 |
# Initialize tensorboard
|
| 15 |
-
|
|
|
|
| 16 |
|
| 17 |
# Setup data
|
| 18 |
data_module = MNISTDataModule(batch_size=64, val_batch_size=1000)
|
|
@@ -24,7 +27,11 @@ def train():
|
|
| 24 |
criterion = nn.CrossEntropyLoss()
|
| 25 |
|
| 26 |
# Training loop
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
for epoch in range(num_epochs):
|
| 29 |
model.train()
|
| 30 |
running_loss = 0.0
|
|
@@ -75,5 +82,14 @@ def train():
|
|
| 75 |
writer.add_scalar('test accuracy', accuracy, epoch)
|
| 76 |
print(f'Epoch {epoch+1}: Test Accuracy: {accuracy:.2f}%')
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
if __name__ == "__main__":
|
| 79 |
train()
|
|
|
|
| 5 |
from torch.utils.tensorboard.writer import SummaryWriter
|
| 6 |
from mnist_classifier.dataset import MNISTDataModule
|
| 7 |
from mnist_classifier.model import MNISTModel
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
import os
|
| 10 |
|
| 11 |
def train():
|
| 12 |
# Set device
|
|
|
|
| 14 |
print(f"Using device: {device}")
|
| 15 |
|
| 16 |
# Initialize tensorboard
|
| 17 |
+
log_dir = 'runs/mnist_experiment_' + datetime.now().strftime('%Y%m%d-%H%M%S')
|
| 18 |
+
writer = SummaryWriter(log_dir)
|
| 19 |
|
| 20 |
# Setup data
|
| 21 |
data_module = MNISTDataModule(batch_size=64, val_batch_size=1000)
|
|
|
|
| 27 |
criterion = nn.CrossEntropyLoss()
|
| 28 |
|
| 29 |
# Training loop
|
| 30 |
+
learning_rate = 0.001
|
| 31 |
+
batch_size = 64
|
| 32 |
+
epochs = 10
|
| 33 |
+
|
| 34 |
+
num_epochs = epochs
|
| 35 |
for epoch in range(num_epochs):
|
| 36 |
model.train()
|
| 37 |
running_loss = 0.0
|
|
|
|
| 82 |
writer.add_scalar('test accuracy', accuracy, epoch)
|
| 83 |
print(f'Epoch {epoch+1}: Test Accuracy: {accuracy:.2f}%')
|
| 84 |
|
| 85 |
+
writer.close()
|
| 86 |
+
|
| 87 |
+
# Ensure the directory exists
|
| 88 |
+
os.makedirs("./models", exist_ok=True)
|
| 89 |
+
|
| 90 |
+
# Format the filename with the config parameters
|
| 91 |
+
filename = f"./models/mnist_model_lr{learning_rate}_bs{batch_size}_ep{epochs}.pth"
|
| 92 |
+
torch.save(model.state_dict(), filename)
|
| 93 |
+
|
| 94 |
if __name__ == "__main__":
|
| 95 |
train()
|
models/mnist_model_lr0.001_bs64_ep10.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f00fa1ee4fd08e6a5c41d3952b64e27b8bb122182f432332e18c9ee2af67609
|
| 3 |
+
size 4803144
|
torchvision.pyi
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
class datasets:
|
| 4 |
+
MNIST: Any
|
| 5 |
+
|
| 6 |
+
class transforms:
|
| 7 |
+
Compose: Any
|
| 8 |
+
ToTensor: Any
|
| 9 |
+
Normalize: Any
|