harry
commited on
Commit
·
55af1cd
1
Parent(s):
5144b79
feat: add seed setting for deterministic training
Browse files
mnist_classifier/train.py
CHANGED
|
@@ -7,8 +7,22 @@ from mnist_classifier.dataset import MNISTDataModule
|
|
| 7 |
from mnist_classifier.model import MNISTModel
|
| 8 |
from datetime import datetime
|
| 9 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def train():
|
|
|
|
|
|
|
|
|
|
| 12 |
# Set device
|
| 13 |
device = torch.device('cuda')
|
| 14 |
print(f"Using device: {device}")
|
|
|
|
| 7 |
from mnist_classifier.model import MNISTModel
|
| 8 |
from datetime import datetime
|
| 9 |
import os
|
| 10 |
+
import random
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
def set_seed(seed):
|
| 14 |
+
torch.manual_seed(seed)
|
| 15 |
+
torch.cuda.manual_seed(seed)
|
| 16 |
+
torch.cuda.manual_seed_all(seed)
|
| 17 |
+
np.random.seed(seed)
|
| 18 |
+
random.seed(seed)
|
| 19 |
+
torch.backends.cudnn.deterministic = True
|
| 20 |
+
torch.backends.cudnn.benchmark = False
|
| 21 |
|
| 22 |
def train():
|
| 23 |
+
# Set seed for reproducibility
|
| 24 |
+
set_seed(42)
|
| 25 |
+
|
| 26 |
# Set device
|
| 27 |
device = torch.device('cuda')
|
| 28 |
print(f"Using device: {device}")
|
models/mnist_model_lr0.001_bs64_ep10.pth
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 4803144
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d1b474acf8a447dea4e3aaaf0371346ee7a7055d1c716fb371c059b9a1799bab
|
| 3 |
size 4803144
|