"""
Train a GatedPixelCNN prior over VQ-VAE latent codes.

Run this from the vqvae directory so the local `pixelcnn` package and
`utils` module are importable.
"""
import argparse
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn

# Make the vqvae and pixelcnn directories importable.
# (Original assigned sys.path.append(...)'s return value — always None — to
# variables; the assignments were meaningless and are dropped.)
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'pixelcnn'))

from pixelcnn.models import GatedPixelCNN
import utils

"""
Hyperparameters
"""
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("-save", action="store_true")
parser.add_argument("-gen_samples", action="store_true")
parser.add_argument("--dataset", type=str, default='LATENT_BLOCK')
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--img_dim", type=int, default=64)
parser.add_argument("--input_dim", type=int, default=1,
                    help='1 for grayscale 3 for rgb')
parser.add_argument("--n_embeddings", type=int, default=3,
                    help='number of embeddings from VQ VAE')
parser.add_argument("--n_layers", type=int, default=5)
parser.add_argument("--learning_rate", type=float, default=3e-4)

args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

"""
data loaders
"""
_, _, train_loader, test_loader, _ = utils.load_data_and_data_loaders(
    'LATENT_BLOCK', args.batch_size)

model = GatedPixelCNN(args.n_embeddings, args.img_dim ** 2,
                      args.n_layers).to(device)
# Use .to(device) everywhere instead of .cuda(): the original mixed the two
# and crashed on CPU-only machines despite computing `device` above.
criterion = nn.CrossEntropyLoss().to(device)
opt = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

"""
train, test, and log
"""


def train():
    """Run one training epoch over train_loader, printing periodic progress.

    The PixelCNN is trained to autoregressively predict the discrete latent
    codes `x` (conditioned on `label`), so `x` serves as both input and
    cross-entropy target.
    """
    model.train()
    train_loss = []
    for batch_idx, (x, label) in enumerate(train_loader):
        start_time = time.time()
        # NOTE(review): test() indexes x[:, 0] before use but train() feeds x
        # as-is — confirm the LATENT_BLOCK loader's tensor layout; left
        # unchanged to preserve existing behavior.
        x = x.to(device)
        label = label.to(device)

        logits = model(x, label)
        # (Removed leftover debug `print(logits.shape); exit(0)` that aborted
        # the whole run on the very first batch.)
        logits = logits.permute(0, 2, 3, 1).contiguous()

        loss = criterion(
            logits.view(-1, args.n_embeddings),
            x.view(-1)
        )

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_loss.append(loss.item())

        if (batch_idx + 1) % args.log_interval == 0:
            # Percent-complete was `log_interval * batch_idx / len(loader)`;
            # corrected to an actual percentage of the epoch.
            print('\tIter: [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
                batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                np.asarray(train_loss)[-args.log_interval:].mean(0),
                time.time() - start_time
            ))


def test():
    """Evaluate on test_loader and return the mean validation loss (float)."""
    model.eval()  # disable dropout/batchnorm training behavior during eval
    start_time = time.time()
    val_loss = []
    with torch.no_grad():
        for batch_idx, (x, label) in enumerate(test_loader):
            if args.dataset == 'LATENT_BLOCK':
                x = (x[:, 0]).to(device)
            else:
                # Quantize [0, 1] grayscale pixels into n_embeddings levels.
                x = (x[:, 0] * (args.n_embeddings - 1)).long().to(device)
            label = label.to(device)

            logits = model(x, label)
            logits = logits.permute(0, 2, 3, 1).contiguous()
            loss = criterion(
                logits.view(-1, args.n_embeddings),
                x.view(-1)
            )
            val_loss.append(loss.item())

    print('Validation Completed!\tLoss: {} Time: {}'.format(
        np.asarray(val_loss).mean(0),
        time.time() - start_time
    ))
    return np.asarray(val_loss).mean(0)


def generate_samples(epoch):
    """Sample a 10x10 grid of latent maps, one row per class label 0-9."""
    label = torch.arange(10).expand(10, 10).contiguous().view(-1)
    label = label.long().to(device)

    x_tilde = model.generate(
        label, shape=(args.img_dim, args.img_dim), batch_size=100)
    print(x_tilde[0])


BEST_LOSS = 999
LAST_SAVED = -1
# Was `range(1, args.epochs)`, which ran one epoch fewer than requested.
for epoch in range(1, args.epochs + 1):
    print("\nEpoch {}:".format(epoch))
    train()
    cur_loss = test()

    if args.save or cur_loss <= BEST_LOSS:
        BEST_LOSS = cur_loss
        LAST_SAVED = epoch
        print("Saving model!")
        # Create the output directory so torch.save can't fail with
        # FileNotFoundError on a fresh checkout.
        os.makedirs('results', exist_ok=True)
        torch.save(model.state_dict(),
                   'results/{}_pixelcnn.pt'.format(args.dataset))
    else:
        print("Not saving model! Last saved: {}".format(LAST_SAVED))

    if args.gen_samples:
        generate_samples(epoch)