Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -110,6 +110,10 @@ def download_data():
|
|
| 110 |
# downloads and loads MNIST test set
|
| 111 |
val_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
|
| 112 |
val_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=False, pin_memory=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
# uses GPU if available
|
| 115 |
if torch.cuda.is_available():
|
|
@@ -119,9 +123,6 @@ else:
|
|
| 119 |
|
| 120 |
device = torch.device(dev)
|
| 121 |
|
| 122 |
-
# gets mean and std of dataset
|
| 123 |
-
mean, std = get_mean_std(train_loader)
|
| 124 |
-
|
| 125 |
|
| 126 |
def run_model():
|
| 127 |
# defines parameters
|
|
|
|
| 110 |
# downloads and loads MNIST test set
|
| 111 |
val_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
|
| 112 |
val_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=False, pin_memory=True)
|
| 113 |
+
|
| 114 |
+
# gets mean and std of dataset
|
| 115 |
+
mean, std = get_mean_std(train_loader)
|
| 116 |
+
|
| 117 |
|
| 118 |
# uses GPU if available
|
| 119 |
if torch.cuda.is_available():
|
|
|
|
| 123 |
|
| 124 |
device = torch.device(dev)
|
| 125 |
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
def run_model():
|
| 128 |
# defines parameters
|