johnowhitaker committed on
Commit
185c7b0
·
1 Parent(s): c09ece0

Upload train_dvq_diff.py

Browse files
Files changed (1) hide show
  1. train_dvq_diff.py +151 -0
train_dvq_diff.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Train a conditional diffusion model (dvq_diff): a UNet that denoises an
HQ image conditioned on a paired LQ image, logged to Weights & Biases."""
import sys
sys.path.append('cclddg')  # make the local cclddg checkout importable
import wandb
from cclddg.data import get_paired_vqgan, tensor_to_image
from cclddg.core import UNet, Discriminator
from cclddg.ddg_context import DDG_Context
from PIL import Image
import torch
import torchvision.transforms as T
from torch_ema import ExponentialMovingAverage # pip install torch-ema
from tqdm import tqdm
import torch.nn.functional as F
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Training params
n_batches = 101000
batch_size= 5 # Lower this if hitting memory issues
lr = 5e-5
img_size=128
sr=1  # super-resolution factor (1, 2 or 4) — selects the transforms below
n_steps=200 # Should try more
grad_accumulation_steps = 6 # batch accumulation parameter

# Log the run config to wandb (note: grad_accumulation_steps is not logged here)
wandb.init(project = 'dvq_diff',
           config={
               'n_batches':n_batches,
               'batch_size':batch_size,
               'lr':lr,
               'img_size':img_size,
               'sr':sr,
               'n_steps':n_steps,
           },
           save_code=True)

# Context — holds the diffusion noise schedule (q_xt_x0 / p_xt used below)
ddg_context = DDG_Context(n_steps=n_steps, beta_min=0.005,
                          beta_max=0.05, device=device)

# Model — 6 input channels: 3 for the noised image xt + 3 for the LQ conditioning,
# concatenated along the channel dim before each forward pass.
unet = UNet(image_channels=6, n_channels=128, ch_mults=(1, 1, 2, 2, 2),
            is_attn=(False, False, False, True, True),
            n_blocks=4, use_z=False, z_dim=8, n_z_channels=16,
            use_cloob=False, n_cloob_channels=256,
            n_time_channels=-1, denom_factor=1000).to(device)
# Warm-start from a previous EMA checkpoint (must exist in the working dir)
unet.load_state_dict(torch.load('desert_dawn_ema_unet_020000.pt'))
47
+
48
# Build the LQ (conditioning) and HQ (target) transforms for the chosen
# super-resolution factor. Using elif + an explicit error: the original
# chain of independent `if`s left lq_tfm/hq_tfm undefined for any other
# value of sr, causing a confusing NameError much later.
if sr == 4:
    # Goal is 4x SR. If image size is 256 (hq) we take 128px from lq (which is already 1/2 res) and scale to 64px then back up to 256
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size//4), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 2:
    # 2x SR: crop half-size from the (already half-res) LQ and upscale to img_size
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 1:
    # No SR: both sides just resized to img_size
    lq_tfm = T.Compose([T.Resize(img_size)])
    hq_tfm = T.Compose([T.Resize(img_size)])
else:
    raise ValueError(f'Unsupported sr factor: {sr} (expected 1, 2 or 4)')
58
+
59
+
60
# Data
data = get_paired_vqgan(batch_size=batch_size)  # yields (lq, hq) image batches
data_iter = iter(data)

# For logging examples: hold out one fixed batch so progress images are
# comparable across the whole run.
n_egs = 10
eg_lq, eg_hq = next(data_iter)
# *2-1 maps pixel values to [-1, 1] — assumes the loader yields [0, 1] tensors; TODO confirm
eg_lq = lq_tfm(eg_lq[:n_egs]).to(device)*2-1
eg_hq = hq_tfm(eg_hq[:n_egs]).to(device)*2-1
69
def eg_im(eg_lq, eg_hq, ddg_context, start_t = 99):
    """Run the reverse diffusion chain from a noised copy of the LQ input and
    build a progress grid as a PIL image.

    Each row is one batch element; columns are intermediate denoising
    snapshots followed by the HQ target and the LQ conditioning image.

    Args:
        eg_lq: LQ conditioning batch in [-1, 1], shape (B, 3, H, W).
        eg_hq: matching HQ target batch (only used for display).
        ddg_context: DDG_Context providing q_xt_x0 (forward noising) and
            p_xt (one reverse step).
        start_t: timestep to start the reverse chain from (clamped to
            n_steps - 1).

    Returns:
        PIL.Image.Image grid of size (cols * img_size, B * img_size).
    """
    batch_size = eg_lq.shape[0]
    all_ims = [[] for _ in range(batch_size)]
    # Fix: derive the device from the input instead of hard-coding .cuda(),
    # so the function also works on the CPU fallback chosen at the top of
    # the script.
    dev = eg_lq.device

    # Start from a noised version of the conditioning image (not pure noise)
    cond_0 = eg_lq
    start_t = min(start_t, ddg_context.n_steps - 1)
    t = torch.tensor(start_t, dtype=torch.long, device=dev)
    x, _ = ddg_context.q_xt_x0(cond_0, t.unsqueeze(0))

    # Snapshot cadence; max(..., 1) avoids modulo-by-zero for start_t < 8
    # (the original `start_t//4 - 1` crashed there). For the start_t values
    # used in this script (120, 199) this yields 5 snapshots, unchanged.
    snap_every = max(start_t // 4 - 1, 1)
    for i in range(start_t):
        t = torch.tensor(start_t - i - 1, dtype=torch.long, device=dev)
        with torch.no_grad():
            # Concatenate current state with conditioning along channels (6ch input)
            unet_input = torch.cat((x, cond_0), dim=1)
            pred_noise = unet(unet_input, t.unsqueeze(0))[:, :3]
            x = ddg_context.p_xt(x, pred_noise, t.unsqueeze(0))
        if i % snap_every == 0:
            for b in range(batch_size):
                all_ims[b].append(tensor_to_image(x[b].cpu()))

    # HQ target:
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(eg_hq[b].cpu()))
    # Input/cond:
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(cond_0[b].cpu()))

    # Size the grid from the actual column count instead of assuming 7,
    # so unusual start_t values cannot index out of range.
    n_cols = len(all_ims[0])  # normally 5 snapshots + HQ + cond = 7
    image = Image.new('RGB', size=(img_size * n_cols, batch_size * img_size))
    for col in range(n_cols):
        for b in range(batch_size):
            image.paste(all_ims[b][col], (col * img_size, b * img_size))

    return image
102
+
103
# Training Loop
losses = [] # Store losses for later plotting
optim = torch.optim.RMSprop(unet.parameters(), lr=lr) # Optimizer
ema = ExponentialMovingAverage(unet.parameters(), decay=0.995) # EMA
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)

for i in tqdm(range(0, n_batches)): # Run through the dataset

    # Get a batch, restarting the iterator when it is exhausted.
    # Fix: the original bare `except: pass` silently re-used the previous
    # batch forever once the loader ran out (and would NameError if the
    # very first next() failed).
    try:
        lq, hq = next(data_iter)
    except StopIteration:
        data_iter = iter(data)
        lq, hq = next(data_iter)
    lq = lq_tfm(lq).to(device)*2-1  # map to [-1, 1] (assumes loader yields [0,1] — TODO confirm)
    hq = hq_tfm(hq).to(device)*2-1
    batch_size = lq.shape[0]  # last batch may be smaller

    x0 = hq
    cond_0 = lq
    t = torch.randint(1, ddg_context.n_steps, (batch_size,), dtype=torch.long).to(device) # Random 't's
    xt, noise = ddg_context.q_xt_x0(x0, t) # Get the noised images (xt) and the noise (our target)
    unet_input = torch.cat((xt, cond_0), dim=1) # Combine with cond
    pred_noise = unet(unet_input, t)[:,:3] # Run xt through the network to get its predictions
    loss = F.mse_loss(noise.float(), pred_noise) # Compare the predictions with the targets
    losses.append(loss.item()) # Store the (unscaled) loss for later viewing
    wandb.log({'Loss':loss.item()}) # Log to wandb
    # Fix: scale by the accumulation window so accumulated gradients
    # average rather than sum — keeps the effective step size independent
    # of grad_accumulation_steps.
    (loss / grad_accumulation_steps).backward()

    # Fix: step at the END of each accumulation window ((i+1) % k) — the
    # original `i % k == 0` stepped at i=0 after only one backward pass.
    if (i + 1) % grad_accumulation_steps == 0:
        optim.step() # Update the network parameters
        optim.zero_grad() # Zero the gradients
        ema.update() # Update the moving average with the new parameters from the last optimizer step

    # Periodically log example image grids (fixed eval batch + current batch)
    if i % 2000 == 0:
        with torch.no_grad():
            wandb.log({'Examples @120':wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t = 120))})
            wandb.log({'Examples @199':wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t = 199))})
            wandb.log({'Random Examples @120':wandb.Image(eg_im(lq, hq, ddg_context, start_t = 120))})

    # Periodically checkpoint both the raw and the EMA weights
    if i % 20000 == 0:
        torch.save(unet.state_dict(), f'unet_{i:06}.pt')
        with ema.average_parameters():  # temporarily swap in EMA weights
            torch.save(unet.state_dict(), f'ema_unet_{i:06}.pt')

    # Decay the learning rate every 4000 batches
    if (i+1)%4000 == 0:
        scheduler.step()
        wandb.log({'lr':optim.param_groups[0]['lr']})