Yash Nagraj commited on
Commit
b5ea21b
·
1 Parent(s): e1dc9af
__pycache__/diffusion.cpython-312.pyc ADDED
Binary file (6.52 kB). View file
 
__pycache__/model.cpython-312.pyc ADDED
Binary file (15.5 kB). View file
 
__pycache__/scheduler.cpython-312.pyc ADDED
Binary file (2.32 kB). View file
 
__pycache__/train.cpython-312.pyc ADDED
Binary file (4.85 kB). View file
 
cifar-10-python.tar.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53c8b1a8d9b386502f69a87ab4d636d25a893ba9c434be754980ec283a5bd227
3
+ size 3768320
main.py CHANGED
@@ -28,7 +28,7 @@ def main(model_config=None):
28
  "nrow": 8
29
  }
30
 
31
- if modelConfig['state'] == train:
32
  train(modelConfig)
33
  else:
34
  eval(modelConfig)
 
28
  "nrow": 8
29
  }
30
 
31
+ if modelConfig['state'] == 'train':
32
  train(modelConfig)
33
  else:
34
  eval(modelConfig)
train.py CHANGED
@@ -15,7 +15,7 @@ def train(modelConfig: Dict):
15
  device = torch.device(modelConfig['device'])
16
 
17
  dataset = CIFAR10(
18
- "./",train=True,transform=transforms.Compose([
19
  transforms.RandomHorizontalFlip(),
20
  transforms.ToTensor(),
21
  transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
@@ -60,7 +60,7 @@ def train(modelConfig: Dict):
60
  "LR": optimizer.state_dict()['param_groups'][0]["lr"]
61
  })
62
  warmupScheduler.step()
63
- torch.save(net_model,os.path.join(modelConfig['checkpoint_dir'] + f"ckpt_{epoch}_.pth"))
64
 
65
 
66
  def eval(modelConfig:Dict):
 
15
  device = torch.device(modelConfig['device'])
16
 
17
  dataset = CIFAR10(
18
+ "./",train=True,download=True,transform=transforms.Compose([
19
  transforms.RandomHorizontalFlip(),
20
  transforms.ToTensor(),
21
  transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
 
60
  "LR": optimizer.state_dict()['param_groups'][0]["lr"]
61
  })
62
  warmupScheduler.step()
63
+ torch.save(net_model,os.path.join(modelConfig['checkpoint_dir'] + f"ckpt_{epoch}.pth"))
64
 
65
 
66
  def eval(modelConfig:Dict):