Yash Nagraj commited on
Commit ·
00315bf
1
Parent(s): 3c19795
Sampled Image using DDIM in Half the number of timesteps
Browse files- Checkpoints/ckpt_199.pth +3 -0
- SampledImgs/SampledNoGuidenceImgs1.png +0 -0
- __pycache__/diffusion.cpython-312.pyc +0 -0
- __pycache__/model.cpython-312.pyc +0 -0
- __pycache__/scheduler.cpython-312.pyc +0 -0
- __pycache__/train.cpython-312.pyc +0 -0
- main.py +5 -3
- train.py +22 -1
Checkpoints/ckpt_199.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e55c1db65805437bc0443bf33ad9749b03531709f563c88bccc828447460e5c1
|
| 3 |
+
size 332549097
|
SampledImgs/SampledNoGuidenceImgs1.png
ADDED
|
__pycache__/diffusion.cpython-312.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (15.7 kB). View file
|
|
|
__pycache__/scheduler.cpython-312.pyc
ADDED
|
Binary file (2.32 kB). View file
|
|
|
__pycache__/train.cpython-312.pyc
ADDED
|
Binary file (5.85 kB). View file
|
|
|
main.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from train import train, eval
|
| 2 |
|
| 3 |
|
| 4 |
def main(model_config=None):
|
|
@@ -6,7 +6,7 @@ def main(model_config=None):
|
|
| 6 |
"state": "eval", # or eval
|
| 7 |
"epochs": 200,
|
| 8 |
"batch_size": 80,
|
| 9 |
-
"T":
|
| 10 |
"channel": 128,
|
| 11 |
"ch_mult": [1, 2, 3, 4],
|
| 12 |
"attn": [2],
|
|
@@ -30,8 +30,10 @@ def main(model_config=None):
|
|
| 30 |
|
| 31 |
if modelConfig['state'] == 'train':
|
| 32 |
train(modelConfig)
|
| 33 |
-
|
| 34 |
eval(modelConfig)
|
|
|
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
if __name__ == "__main__":
|
|
|
|
| 1 |
+
from train import eval_ddim, train, eval
|
| 2 |
|
| 3 |
|
| 4 |
def main(model_config=None):
|
|
|
|
| 6 |
"state": "eval", # or eval
|
| 7 |
"epochs": 200,
|
| 8 |
"batch_size": 80,
|
| 9 |
+
"T": 500,
|
| 10 |
"channel": 128,
|
| 11 |
"ch_mult": [1, 2, 3, 4],
|
| 12 |
"attn": [2],
|
|
|
|
| 30 |
|
| 31 |
if modelConfig['state'] == 'train':
|
| 32 |
train(modelConfig)
|
| 33 |
+
elif modelConfig['state'] == 'eval':
|
| 34 |
eval(modelConfig)
|
| 35 |
+
elif modelConfig['state'] == 'eval_ddim':
|
| 36 |
+
eval_ddim(modelConfig)
|
| 37 |
|
| 38 |
|
| 39 |
if __name__ == "__main__":
|
train.py
CHANGED
|
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
|
|
| 7 |
from torchvision import transforms
|
| 8 |
from torchvision.datasets import CIFAR10
|
| 9 |
from torchvision.utils import save_image
|
| 10 |
-
from diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer
|
| 11 |
from model import UNet
|
| 12 |
from scheduler import GradualWarmupScheduler
|
| 13 |
|
|
@@ -84,3 +84,24 @@ def eval(modelConfig:Dict):
|
|
| 84 |
nrow = modelConfig['nrow']
|
| 85 |
)
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from torchvision import transforms
|
| 8 |
from torchvision.datasets import CIFAR10
|
| 9 |
from torchvision.utils import save_image
|
| 10 |
+
from diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer, DDIMSampler
|
| 11 |
from model import UNet
|
| 12 |
from scheduler import GradualWarmupScheduler
|
| 13 |
|
|
|
|
| 84 |
nrow = modelConfig['nrow']
|
| 85 |
)
|
| 86 |
|
| 87 |
+
|
| 88 |
+
def eval_ddim(modelConfig:Dict):
|
| 89 |
+
with torch.no_grad():
|
| 90 |
+
device = torch.device(modelConfig['device'])
|
| 91 |
+
model = torch.load(os.path.join(modelConfig['checkpoint_dir'],modelConfig['test_load_weight']),device)
|
| 92 |
+
print("Model loaded")
|
| 93 |
+
model.eval()
|
| 94 |
+
sampler = DDIMSampler(
|
| 95 |
+
modelConfig['beta_1'], modelConfig['beta_T'],
|
| 96 |
+
model,modelConfig['T']
|
| 97 |
+
)
|
| 98 |
+
noisyImage = torch.randn(
|
| 99 |
+
size=[modelConfig['batch_size'],3,32,32],
|
| 100 |
+
device=device
|
| 101 |
+
)
|
| 102 |
+
sampledImgs = sampler(noisyImage)
|
| 103 |
+
sampledImgs = sampledImgs * 0.5 + 0.5
|
| 104 |
+
save_image(sampledImgs,
|
| 105 |
+
os.path.join(modelConfig['sample_dir'],modelConfig['sampledImgName']),
|
| 106 |
+
nrow = modelConfig['nrow']
|
| 107 |
+
)
|