Spaces:
Build error
Build error
File size: 4,390 Bytes
e7610f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import numpy as np
from torchvision.utils import save_image
import time
import os
import sys
"""
add vqvae and pixelcnn dirs to path
make sure you run from vqvae directory
"""
# sys.path.append() returns None, so resolve the directories first and keep
# them in the module-level names instead of binding those names to None.
current_dir = os.getcwd()
pixelcnn_dir = os.path.join(current_dir, 'pixelcnn')
# Both directories must be importable: the working directory for `utils`
# and the `pixelcnn` package, and the pixelcnn directory itself.
sys.path.append(current_dir)
sys.path.append(pixelcnn_dir)
from pixelcnn.models import GatedPixelCNN
import utils
"""
Hyperparameters
"""
import argparse

# Command-line options as (flag, add_argument keyword arguments). They are
# registered in the order listed, which is also the order shown by --help.
_CLI_OPTIONS = (
    ("--batch_size", {"type": int, "default": 1}),
    ("--epochs", {"type": int, "default": 100}),
    ("--log_interval", {"type": int, "default": 100}),
    ("-save", {"action": "store_true"}),
    ("-gen_samples", {"action": "store_true"}),
    ("--dataset", {"type": str, "default": 'LATENT_BLOCK'}),
    ("--num_workers", {"type": int, "default": 0}),
    ("--img_dim", {"type": int, "default": 64}),
    ("--input_dim", {"type": int, "default": 1,
                     "help": '1 for grayscale 3 for rgb'}),
    ("--n_embeddings", {"type": int, "default": 3,
                        "help": 'number of embeddings from VQ VAE'}),
    ("--n_layers", {"type": int, "default": 5}),
    ("--learning_rate", {"type": float, "default": 3e-4}),
)

parser = argparse.ArgumentParser()
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
"""
data loaders
"""
# Load the dataset selected with --dataset (default 'LATENT_BLOCK') so the
# data matches the branch taken in test() and the checkpoint filename, both
# of which key off args.dataset.
# NOTE(review): assumes utils.load_data_and_data_loaders returns the same
# 5-tuple for every dataset name it accepts — confirm against utils.
_, _, train_loader, test_loader, _ = utils.load_data_and_data_loaders(
    args.dataset, args.batch_size)

# Model, loss and optimizer. Everything is placed on `device` so the script
# follows the CUDA/CPU choice made above rather than hard-coding CUDA.
model = GatedPixelCNN(args.n_embeddings, args.img_dim**2, args.n_layers).to(device)
criterion = nn.CrossEntropyLoss().to(device)
opt = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
"""
train, test, and log
"""
def train():
    """Run one pass over ``train_loader`` and update the model.

    For each batch the inputs and labels are moved to ``device``, the
    label-conditioned logits are computed, and the cross-entropy between the
    logits and the inputs themselves is minimized with ``opt``. A progress
    line is printed every ``args.log_interval`` batches.
    """
    train_loss = []
    for batch_idx, (x, label) in enumerate(train_loader):
        start_time = time.time()

        # NOTE(review): test() selects ``x[:, 0]`` before the forward pass,
        # while the batch is used here exactly as loaded — confirm which
        # layout the training loader yields.
        x = x.to(device)
        label = label.to(device)

        # Train PixelCNN with images
        logits = model(x, label)
        # Move dim 1 last so the logits flatten to rows of
        # ``args.n_embeddings`` scores, one row per target in ``x``.
        logits = logits.permute(0, 2, 3, 1).contiguous()
        loss = criterion(
            logits.view(-1, args.n_embeddings),
            x.view(-1)
        )

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_loss.append(loss.item())

        if (batch_idx + 1) % args.log_interval == 0:
            # Time covers only the most recent batch, because start_time is
            # reset at the top of every iteration.
            print('\tIter: [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
                batch_idx * len(x), len(train_loader.dataset),
                # Percentage of batches processed; a fixed factor of 100
                # keeps it correct for any --log_interval value.
                100.0 * batch_idx / len(train_loader),
                np.asarray(train_loss)[-args.log_interval:].mean(0),
                time.time() - start_time
            ))
def test():
    """Evaluate the model on ``test_loader`` without tracking gradients.

    Prints the mean per-batch cross-entropy loss and the elapsed time.

    Returns:
        The mean of the per-batch losses over ``test_loader``.
    """
    start_time = time.time()
    val_loss = []
    with torch.no_grad():
        for x, label in test_loader:
            if args.dataset == 'LATENT_BLOCK':
                x = x[:, 0].to(device)
            else:
                # presumably rescales values in [0, 1] to integer indices in
                # [0, n_embeddings - 1]; verify against the data loader.
                x = (x[:, 0] * (args.n_embeddings - 1)).long().to(device)
            label = label.to(device)

            logits = model(x, label)
            # Move dim 1 last so the logits flatten to rows of
            # ``args.n_embeddings`` scores, one row per target in ``x``.
            logits = logits.permute(0, 2, 3, 1).contiguous()
            loss = criterion(
                logits.view(-1, args.n_embeddings),
                x.view(-1)
            )
            val_loss.append(loss.item())

    # Compute the mean once and reuse it for both the log line and the return.
    mean_loss = np.asarray(val_loss).mean(0)
    print('Validation Completed!\tLoss: {} Time: {}'.format(
        mean_loss,
        time.time() - start_time
    ))
    return mean_loss
def generate_samples(epoch):
    """Generate 100 label-conditioned samples and print the first one.

    Args:
        epoch: Epoch number supplied by the caller; not used by this function.
    """
    # Labels 0..9 repeated ten times -> 100 labels, matching batch_size below.
    label = torch.arange(10).expand(10, 10).contiguous().view(-1)
    # Place the labels on `device` (same device as the model) instead of
    # hard-coding CUDA, so sampling also works with the CPU fallback.
    label = label.long().to(device)

    x_tilde = model.generate(
        label, shape=(args.img_dim, args.img_dim), batch_size=100)

    print(x_tilde[0])
# Best validation loss seen so far; infinity guarantees the first epoch's
# checkpoint is always written, whatever the scale of the loss.
BEST_LOSS = float("inf")
# Epoch at which the checkpoint was last written (-1: nothing saved yet).
LAST_SAVED = -1

# torch.save does not create missing directories, so make sure the
# checkpoint directory exists before the first save.
os.makedirs('results', exist_ok=True)

# Epochs are numbered from 1; the upper bound is inclusive so that exactly
# args.epochs epochs are run.
for epoch in range(1, args.epochs + 1):
    print("\nEpoch {}:".format(epoch))
    train()
    cur_loss = test()

    # Save on every epoch when -save is given, otherwise only when the
    # validation loss is at least as good as the best one so far.
    if args.save or cur_loss <= BEST_LOSS:
        BEST_LOSS = cur_loss
        LAST_SAVED = epoch

        print("Saving model!")
        torch.save(model.state_dict(), 'results/{}_pixelcnn.pt'.format(args.dataset))
    else:
        print("Not saving model! Last saved: {}".format(LAST_SAVED))

    if args.gen_samples:
        generate_samples(epoch)
|