File size: 4,059 Bytes
990d40a
 
 
 
 
 
 
 
 
00315bf
990d40a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00315bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
from typing import Dict
import torch 
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image
from diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer, DDIMSampler
from model import UNet
from scheduler import GradualWarmupScheduler

def train(modelConfig: Dict):
    """Train a DDPM UNet on CIFAR-10 and checkpoint the model every epoch.

    Expects modelConfig keys: device, batch_size, T, channel, ch_mult, attn,
    num_res_blocks, dropout, lr, epochs, multiplier, beta_1, beta_T,
    grad_clip, checkpoint_dir.
    """
    device = torch.device(modelConfig['device'])

    # CIFAR-10 normalized to [-1, 1], with horizontal-flip augmentation.
    dataset = CIFAR10(
        "./", train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    )

    dataloader = DataLoader(
        dataset, batch_size=modelConfig['batch_size'], shuffle=True,
        num_workers=4, drop_last=True, pin_memory=True
    )

    net_model = UNet(modelConfig['T'], modelConfig['channel'], modelConfig['ch_mult'],
                     modelConfig['attn'], modelConfig['num_res_blocks'], modelConfig['dropout'])
    optimizer = optim.AdamW(net_model.parameters(), modelConfig['lr'], weight_decay=1e-4)
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, modelConfig['epochs'], eta_min=0, last_epoch=-1
    )

    # Warm up for the first 10% of epochs, then hand over to cosine annealing.
    warmupScheduler = GradualWarmupScheduler(
        optimizer, modelConfig['multiplier'], modelConfig['epochs'] // 10,
        cosineScheduler
    )

    # .to(device) also moves net_model, since it is a submodule of the trainer.
    trainer = GaussianDiffusionTrainer(
        modelConfig['beta_1'],
        modelConfig['beta_T'],
        modelConfig['T'],
        net_model).to(device)

    # FIX: the original passed a single pre-concatenated string to
    # os.path.join, so a checkpoint_dir lacking a trailing separator
    # produced paths like "checkpointsckpt_0.pth". Join dir and file
    # properly, and make sure the directory exists before the first save.
    os.makedirs(modelConfig['checkpoint_dir'], exist_ok=True)

    for epoch in range(modelConfig['epochs']):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, _ in tqdmDataLoader:
                optimizer.zero_grad()
                x_0 = images.to(device)
                # Summed per-element loss scaled by a fixed 1/1000
                # (original's constant; presumably a loss-scale close to
                # 1/(32*32) — TODO confirm intent before changing).
                loss = trainer(x_0).sum() / 1000
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig['grad_clip']
                )
                optimizer.step()
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": epoch,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                })
        warmupScheduler.step()
        # Saves the full module object; eval() relies on torch.load of it.
        torch.save(net_model, os.path.join(modelConfig['checkpoint_dir'], f"ckpt_{epoch}.pth"))


def eval(modelConfig: Dict):
    """Sample a batch with the ancestral (DDPM) sampler and save an image grid.

    Loads the full model object saved by train() from
    checkpoint_dir/test_load_weight and writes sample_dir/sampledImgName.
    NOTE(review): this name shadows the builtin eval(); kept unchanged for
    caller compatibility.
    """
    with torch.no_grad():
        device = torch.device(modelConfig['device'])
        model = torch.load(
            os.path.join(modelConfig['checkpoint_dir'], modelConfig['test_load_weight']),
            device
        )
        print("Model loaded")
        model.eval()
        # FIX: move the sampler to the target device, mirroring train()'s
        # handling of GaussianDiffusionTrainer; otherwise its registered
        # schedule buffers stay on CPU while noisyImage lives on `device`.
        sampler = GaussianDiffusionSampler(
            modelConfig['beta_1'], modelConfig['beta_T'],
            model, modelConfig['T']
        ).to(device)
        noisyImage = torch.randn(
            size=[modelConfig['batch_size'], 3, 32, 32],
            device=device
        )
        sampledImgs = sampler(noisyImage)
        # Map model output from [-1, 1] back to [0, 1] for save_image.
        sampledImgs = sampledImgs * 0.5 + 0.5
        save_image(
            sampledImgs,
            os.path.join(modelConfig['sample_dir'], modelConfig['sampledImgName']),
            nrow=modelConfig['nrow']
        )


def eval_ddim(modelConfig: Dict):
    """Sample a batch with the DDIM sampler and save an image grid.

    Loads the full model object saved by train() from
    checkpoint_dir/test_load_weight and writes sample_dir/sampledImgName.
    """
    with torch.no_grad():
        device = torch.device(modelConfig['device'])
        model = torch.load(
            os.path.join(modelConfig['checkpoint_dir'], modelConfig['test_load_weight']),
            device
        )
        print("Model loaded")
        model.eval()
        # FIX: move the sampler to the target device, mirroring train()'s
        # handling of GaussianDiffusionTrainer; otherwise its registered
        # schedule buffers stay on CPU while noisyImage lives on `device`.
        sampler = DDIMSampler(
            modelConfig['beta_1'], modelConfig['beta_T'],
            model, modelConfig['T']
        ).to(device)
        noisyImage = torch.randn(
            size=[modelConfig['batch_size'], 3, 32, 32],
            device=device
        )
        sampledImgs = sampler(noisyImage)
        # Map model output from [-1, 1] back to [0, 1] for save_image.
        sampledImgs = sampledImgs * 0.5 + 0.5
        save_image(
            sampledImgs,
            os.path.join(modelConfig['sample_dir'], modelConfig['sampledImgName']),
            nrow=modelConfig['nrow']
        )