Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -101,19 +101,14 @@ class Net(nn.Module):
|
|
| 101 |
return x
|
| 102 |
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, pin_memory=True)
|
| 109 |
-
|
| 110 |
-
# downloads and loads MNIST test set
|
| 111 |
-
val_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
|
| 112 |
-
val_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=False, pin_memory=True)
|
| 113 |
-
|
| 114 |
-
# gets mean and std of dataset
|
| 115 |
-
mean, std = get_mean_std(train_loader)
|
| 116 |
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
# uses GPU if available
|
| 119 |
if torch.cuda.is_available():
|
|
@@ -123,6 +118,9 @@ else:
|
|
| 123 |
|
| 124 |
device = torch.device(dev)
|
| 125 |
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
def run_model():
|
| 128 |
# defines parameters
|
|
|
|
| 101 |
return x
|
| 102 |
|
| 103 |
|
| 104 |
+
# downloads and loads MNIST train set
|
| 105 |
+
transform = transforms.Compose([transforms.ToTensor(), transforms.RandomAffine(degrees=10, translate=(0.1,0.1))])
|
| 106 |
+
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
|
| 107 |
+
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, pin_memory=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
+
# downloads and loads MNIST test set
|
| 110 |
+
val_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
|
| 111 |
+
val_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=False, pin_memory=True)
|
| 112 |
|
| 113 |
# uses GPU if available
|
| 114 |
if torch.cuda.is_available():
|
|
|
|
| 118 |
|
| 119 |
device = torch.device(dev)
|
| 120 |
|
| 121 |
+
# gets mean and std of dataset
|
| 122 |
+
mean, std = get_mean_std(train_loader)
|
| 123 |
+
|
| 124 |
|
| 125 |
def run_model():
|
| 126 |
# defines parameters
|