Yash Nagraj committed on
Commit
00315bf
·
1 Parent(s): 3c19795

Sampled Image using DDIM in Half the number of timesteps

Browse files
Checkpoints/ckpt_199.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e55c1db65805437bc0443bf33ad9749b03531709f563c88bccc828447460e5c1
3
+ size 332549097
SampledImgs/SampledNoGuidenceImgs1.png ADDED
__pycache__/diffusion.cpython-312.pyc ADDED
Binary file (13.8 kB). View file
 
__pycache__/model.cpython-312.pyc ADDED
Binary file (15.7 kB). View file
 
__pycache__/scheduler.cpython-312.pyc ADDED
Binary file (2.32 kB). View file
 
__pycache__/train.cpython-312.pyc ADDED
Binary file (5.85 kB). View file
 
main.py CHANGED
@@ -1,4 +1,4 @@
1
- from train import train, eval
2
 
3
 
4
  def main(model_config=None):
@@ -6,7 +6,7 @@ def main(model_config=None):
6
  "state": "eval", # or eval
7
  "epochs": 200,
8
  "batch_size": 80,
9
- "T": 1000,
10
  "channel": 128,
11
  "ch_mult": [1, 2, 3, 4],
12
  "attn": [2],
@@ -30,8 +30,10 @@ def main(model_config=None):
30
 
31
  if modelConfig['state'] == 'train':
32
  train(modelConfig)
33
- else:
34
  eval(modelConfig)
 
 
35
 
36
 
37
  if __name__ == "__main__":
 
1
+ from train import eval_ddim, train, eval
2
 
3
 
4
  def main(model_config=None):
 
6
  "state": "eval", # or eval
7
  "epochs": 200,
8
  "batch_size": 80,
9
+ "T": 500,
10
  "channel": 128,
11
  "ch_mult": [1, 2, 3, 4],
12
  "attn": [2],
 
30
 
31
  if modelConfig['state'] == 'train':
32
  train(modelConfig)
33
+ elif modelConfig['state'] == 'eval':
34
  eval(modelConfig)
35
+ elif modelConfig['state'] == 'eval_ddim':
36
+ eval_ddim(modelConfig)
37
 
38
 
39
  if __name__ == "__main__":
train.py CHANGED
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
7
  from torchvision import transforms
8
  from torchvision.datasets import CIFAR10
9
  from torchvision.utils import save_image
10
- from diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer
11
  from model import UNet
12
  from scheduler import GradualWarmupScheduler
13
 
@@ -84,3 +84,24 @@ def eval(modelConfig:Dict):
84
  nrow = modelConfig['nrow']
85
  )
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from torchvision import transforms
8
  from torchvision.datasets import CIFAR10
9
  from torchvision.utils import save_image
10
+ from diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer, DDIMSampler
11
  from model import UNet
12
  from scheduler import GradualWarmupScheduler
13
 
 
84
  nrow = modelConfig['nrow']
85
  )
86
 
87
+
88
def eval_ddim(modelConfig: Dict):
    """Sample a batch of images from a trained diffusion model via DDIM.

    Loads a full model checkpoint from disk, draws Gaussian noise shaped
    like a CIFAR-10 batch, denoises it with ``DDIMSampler``, rescales the
    output from [-1, 1] to [0, 1], and saves the resulting image grid.

    Args:
        modelConfig: configuration dict; keys read here are 'device',
            'checkpoint_dir', 'test_load_weight', 'beta_1', 'beta_T', 'T',
            'batch_size', 'sample_dir', 'sampledImgName', and 'nrow'.
    """
    # Pure inference: no gradients are needed anywhere below.
    with torch.no_grad():
        device = torch.device(modelConfig['device'])
        # NOTE(review): this checkpoint stores a whole pickled model object,
        # so only load files from trusted sources. weights_only=False is
        # spelled out so the call keeps working on torch >= 2.6, where the
        # default flipped to True and would refuse a full-model pickle.
        model = torch.load(
            os.path.join(modelConfig['checkpoint_dir'], modelConfig['test_load_weight']),
            map_location=device,
            weights_only=False,
        )
        print("Model loaded")
        model.eval()
        sampler = DDIMSampler(
            modelConfig['beta_1'], modelConfig['beta_T'],
            model, modelConfig['T']
        )
        # Start from pure Gaussian noise; 3x32x32 matches the CIFAR-10
        # images the model was trained on.
        noisyImage = torch.randn(
            size=[modelConfig['batch_size'], 3, 32, 32],
            device=device
        )
        sampledImgs = sampler(noisyImage)
        # Map model output from [-1, 1] to [0, 1] for image saving.
        sampledImgs = sampledImgs * 0.5 + 0.5
        save_image(
            sampledImgs,
            os.path.join(modelConfig['sample_dir'], modelConfig['sampledImgName']),
            nrow=modelConfig['nrow']
        )