Yash Nagraj committed on
Commit ·
b5ea21b
1
Parent(s): e1dc9af
Dumb fix
Browse files- __pycache__/diffusion.cpython-312.pyc +0 -0
- __pycache__/model.cpython-312.pyc +0 -0
- __pycache__/scheduler.cpython-312.pyc +0 -0
- __pycache__/train.cpython-312.pyc +0 -0
- cifar-10-python.tar.gz +3 -0
- main.py +1 -1
- train.py +2 -2
__pycache__/diffusion.cpython-312.pyc
ADDED
|
Binary file (6.52 kB). View file
|
|
|
__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
__pycache__/scheduler.cpython-312.pyc
ADDED
|
Binary file (2.32 kB). View file
|
|
|
__pycache__/train.cpython-312.pyc
ADDED
|
Binary file (4.85 kB). View file
|
|
|
cifar-10-python.tar.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:53c8b1a8d9b386502f69a87ab4d636d25a893ba9c434be754980ec283a5bd227
|
| 3 |
+
size 3768320
|
main.py
CHANGED
|
@@ -28,7 +28,7 @@ def main(model_config=None):
|
|
| 28 |
"nrow": 8
|
| 29 |
}
|
| 30 |
|
| 31 |
-
if modelConfig['state'] == train:
|
| 32 |
train(modelConfig)
|
| 33 |
else:
|
| 34 |
eval(modelConfig)
|
|
|
|
| 28 |
"nrow": 8
|
| 29 |
}
|
| 30 |
|
| 31 |
+
if modelConfig['state'] == 'train':
|
| 32 |
train(modelConfig)
|
| 33 |
else:
|
| 34 |
eval(modelConfig)
|
train.py
CHANGED
|
@@ -15,7 +15,7 @@ def train(modelConfig: Dict):
|
|
| 15 |
device = torch.device(modelConfig['device'])
|
| 16 |
|
| 17 |
dataset = CIFAR10(
|
| 18 |
-
"./",train=True,transform=transforms.Compose([
|
| 19 |
transforms.RandomHorizontalFlip(),
|
| 20 |
transforms.ToTensor(),
|
| 21 |
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
|
|
@@ -60,7 +60,7 @@ def train(modelConfig: Dict):
|
|
| 60 |
"LR": optimizer.state_dict()['param_groups'][0]["lr"]
|
| 61 |
})
|
| 62 |
warmupScheduler.step()
|
| 63 |
-
torch.save(net_model,os.path.join(modelConfig['checkpoint_dir'] + f"ckpt_{epoch}
|
| 64 |
|
| 65 |
|
| 66 |
def eval(modelConfig:Dict):
|
|
|
|
| 15 |
device = torch.device(modelConfig['device'])
|
| 16 |
|
| 17 |
dataset = CIFAR10(
|
| 18 |
+
"./",train=True,download=True,transform=transforms.Compose([
|
| 19 |
transforms.RandomHorizontalFlip(),
|
| 20 |
transforms.ToTensor(),
|
| 21 |
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
|
|
|
|
| 60 |
"LR": optimizer.state_dict()['param_groups'][0]["lr"]
|
| 61 |
})
|
| 62 |
warmupScheduler.step()
|
| 63 |
+
torch.save(net_model,os.path.join(modelConfig['checkpoint_dir'] + f"ckpt_{epoch}.pth"))
|
| 64 |
|
| 65 |
|
| 66 |
def eval(modelConfig:Dict):
|