Commit
·
185c7b0
1
Parent(s):
c09ece0
Upload train_dvq_diff.py
Browse files- train_dvq_diff.py +151 -0
train_dvq_diff.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('cclddg')
|
| 3 |
+
import wandb
|
| 4 |
+
from cclddg.data import get_paired_vqgan, tensor_to_image
|
| 5 |
+
from cclddg.core import UNet, Discriminator
|
| 6 |
+
from cclddg.ddg_context import DDG_Context
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torch
|
| 9 |
+
import torchvision.transforms as T
|
| 10 |
+
from torch_ema import ExponentialMovingAverage # pip install torch-ema
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
# Device selection: use the first GPU when present, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# --- Training hyperparameters ---
n_batches = 101000              # total number of training iterations
batch_size = 5                  # Lower this if hitting memory issues
lr = 5e-5                       # optimizer learning rate
img_size = 128                  # training crop/resize resolution
sr = 1                          # super-resolution factor (1 = none)
n_steps = 200                   # diffusion timesteps; Should try more
grad_accumulation_steps = 6     # batch accumulation parameter
|
| 23 |
+
|
| 24 |
+
# Start the Weights & Biases run, recording the hyperparameters above and
# a copy of this script (save_code=True).
run_config = {
    'n_batches': n_batches,
    'batch_size': batch_size,
    'lr': lr,
    'img_size': img_size,
    'sr': sr,
    'n_steps': n_steps,
}
wandb.init(project='dvq_diff', config=run_config, save_code=True)
|
| 34 |
+
|
| 35 |
+
# Context
|
| 36 |
+
# Diffusion context: holds the n_steps noise schedule (betas in
# [0.005, 0.05]) and the q_xt_x0 / p_xt sampling helpers used below.
ddg_context = DDG_Context(
    n_steps=n_steps,
    beta_min=0.005,
    beta_max=0.05,
    device=device,
)
|
| 38 |
+
|
| 39 |
+
# Model
|
| 40 |
+
# Model: 6 input channels = 3 (noisy image) + 3 (conditioning image),
# warm-started from a previous EMA checkpoint.
unet = UNet(image_channels=6, n_channels=128, ch_mults=(1, 1, 2, 2, 2),
            is_attn=(False, False, False, True, True),
            n_blocks=4, use_z=False, z_dim=8, n_z_channels=16,
            use_cloob=False, n_cloob_channels=256,
            n_time_channels=-1, denom_factor=1000).to(device)
# map_location makes the load work on CPU-only machines too; without it a
# GPU-saved checkpoint raises when CUDA is unavailable, even though `device`
# is chosen dynamically above.
unet.load_state_dict(torch.load('desert_dawn_ema_unet_020000.pt',
                                map_location=device))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Build the low-quality (conditioning) and high-quality (target) transforms
# for the chosen super-resolution factor. The lq source is presumably
# already at 1/2 the hq resolution (hence the //2 crops) — TODO confirm
# against get_paired_vqgan.
if sr == 4:
    # Goal is 4x SR. If image size is 256 (hq) we take 128px from lq (which
    # is already 1/2 res) and scale to 64px then back up to 256.
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size//4), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 2:
    lq_tfm = T.Compose([T.CenterCrop(img_size//2), T.Resize(img_size)])
    hq_tfm = T.CenterCrop(img_size)
elif sr == 1:
    # No SR: both sides are simply resized to the training resolution.
    lq_tfm = T.Compose([T.Resize(img_size)])
    hq_tfm = T.Compose([T.Resize(img_size)])
else:
    # Fail fast: previously an unsupported sr left lq_tfm/hq_tfm undefined
    # and the script died later with a confusing NameError.
    raise ValueError(f'Unsupported sr factor: {sr!r} (expected 1, 2 or 4)')
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Data
|
| 61 |
+
# Paired (low-quality, high-quality) VQGAN data stream.
data = get_paired_vqgan(batch_size=batch_size)
data_iter = iter(data)

# Hold out one fixed batch for periodic example logging. Values are mapped
# via x*2-1 (assumes the loader yields images in [0, 1] — TODO confirm).
n_egs = 10
eg_lq, eg_hq = next(data_iter)
eg_lq = lq_tfm(eg_lq[:n_egs]).to(device) * 2 - 1
eg_hq = hq_tfm(eg_hq[:n_egs]).to(device) * 2 - 1
|
| 69 |
+
def eg_im(eg_lq, eg_hq, ddg_context, start_t=99):
    """Sample the reverse diffusion chain and build a progress grid image.

    Starts from the conditioning (low-quality) batch noised to `start_t`,
    denoises step by step with the module-level `unet`, and pastes the
    intermediate snapshots plus the HQ target and the conditioning input
    into a grid with one row per example.

    Args:
        eg_lq: conditioning images, shape (B, 3, img_size, img_size) in [-1, 1]
            (assumed — matches the normalisation applied at the call sites).
        eg_hq: target images, same shape/range; only displayed, never denoised.
        ddg_context: diffusion schedule providing q_xt_x0 / p_xt.
        start_t: timestep to start denoising from (clamped to n_steps - 1).

    Returns:
        A PIL.Image of size (img_size * n_cols, B * img_size), where n_cols
        is 7 for the start_t values used in this script (5 snapshots + HQ +
        cond).
    """
    batch_size = eg_lq.shape[0]
    # Follow the input tensor's device; the original hard-coded .cuda(),
    # which broke the CPU fallback chosen at the top of the file.
    dev = eg_lq.device
    all_ims = [[] for _ in range(batch_size)]

    # Start from a noised copy of the conditioning image.
    cond_0 = eg_lq
    start_t = min(start_t, ddg_context.n_steps - 1)
    t = torch.tensor(start_t, dtype=torch.long, device=dev)
    x, _ = ddg_context.q_xt_x0(cond_0, t.unsqueeze(0))

    # Snapshot interval: keeps ~5 intermediates for the start_t values used
    # below (120 and 199). max(1, ...) guards against the ZeroDivisionError
    # the original hit for start_t <= 8.
    snap_every = max(1, start_t // 4 - 1)
    for i in range(start_t):
        t = torch.tensor(start_t - i - 1, dtype=torch.long, device=dev)
        with torch.no_grad():
            unet_input = torch.cat((x, cond_0), dim=1)
            pred_noise = unet(unet_input, t.unsqueeze(0))[:, :3]
            x = ddg_context.p_xt(x, pred_noise, t.unsqueeze(0))
        if i % snap_every == 0:
            for b in range(batch_size):
                all_ims[b].append(tensor_to_image(x[b].cpu()))

    # HQ target:
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(eg_hq[b].cpu()))
    # Input/cond:
    for b in range(batch_size):
        all_ims[b].append(tensor_to_image(cond_0[b].cpu()))

    # Grid width follows the actual number of collected images (7 for the
    # start_t values used in this script; the original hard-coded 7 and
    # could IndexError for small start_t).
    n_cols = len(all_ims[0])
    image = Image.new('RGB', size=(img_size * n_cols, batch_size * img_size))
    for col in range(n_cols):
        for b in range(batch_size):
            image.paste(all_ims[b][col], (col * img_size, b * img_size))

    return image
|
| 102 |
+
|
| 103 |
+
# Training Loop
|
| 104 |
+
losses = []  # Store losses for later plotting
optim = torch.optim.RMSprop(unet.parameters(), lr=lr)  # Optimizer
ema = ExponentialMovingAverage(unet.parameters(), decay=0.995)  # EMA of the weights
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)

for i in tqdm(range(0, n_batches)):  # Run through the dataset

    # Get a batch; restart the iterator when the dataset is exhausted.
    # (The original bare `except: pass` silently re-used the previous batch
    # forever once the iterator ran out — or raised NameError on the first
    # failure.)
    try:
        lq, hq = next(data_iter)
    except StopIteration:
        data_iter = iter(data)
        lq, hq = next(data_iter)
    lq = lq_tfm(lq).to(device) * 2 - 1  # normalise to [-1, 1]
    hq = hq_tfm(hq).to(device) * 2 - 1
    batch_size = lq.shape[0]

    # Diffusion training step: noise the HQ target at a random timestep and
    # train the UNet (conditioned on the LQ image) to predict that noise.
    x0 = hq
    cond_0 = lq
    t = torch.randint(1, ddg_context.n_steps, (batch_size,), dtype=torch.long).to(device)  # Random 't's
    xt, noise = ddg_context.q_xt_x0(x0, t)  # Noised images (xt) and the noise (our target)
    unet_input = torch.cat((xt, cond_0), dim=1)  # Combine with cond
    pred_noise = unet(unet_input, t)[:, :3]  # Run xt through the network
    loss = F.mse_loss(noise.float(), pred_noise)  # Compare predictions with targets
    losses.append(loss.item())  # Store the (unscaled) loss for later viewing
    wandb.log({'Loss': loss.item()})  # Log to wandb
    # Scale before backward so the accumulated gradient is the *mean* over
    # the virtual batch; without this the effective lr was ~6x larger than
    # the configured one.
    (loss / grad_accumulation_steps).backward()

    # NOTE(review): stepping on i % N == 0 means the very first step (i=0)
    # accumulates only one micro-batch; kept as-is to preserve schedule.
    if i % grad_accumulation_steps == 0:
        optim.step()  # Update the network parameters
        optim.zero_grad()  # Zero the gradients
        ema.update()  # Update the moving average with the freshly stepped params

    if i % 2000 == 0:
        with torch.no_grad():
            wandb.log({'Examples @120': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=120))})
            wandb.log({'Examples @199': wandb.Image(eg_im(eg_lq, eg_hq, ddg_context, start_t=199))})
            wandb.log({'Random Examples @120': wandb.Image(eg_im(lq, hq, ddg_context, start_t=120))})

    if i % 20000 == 0:
        torch.save(unet.state_dict(), f'unet_{i:06}.pt')
        with ema.average_parameters():  # temporarily swap in the EMA weights
            torch.save(unet.state_dict(), f'ema_unet_{i:06}.pt')

    if (i + 1) % 4000 == 0:
        scheduler.step()  # decay lr by gamma=0.9 every 4000 batches
        wandb.log({'lr': optim.param_groups[0]['lr']})
|
| 151 |
+
|