Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from torchvision import datasets, transforms | |
| import numpy as np | |
| from torchvision.utils import save_image | |
| import time | |
| import os | |
| import sys | |
| """ | |
| add vqvae and pixelcnn dirs to path | |
| make sure you run from vqvae directory | |
| """ | |
# Make the working directory and its ``pixelcnn`` sub-directory importable.
# ``sys.path.append`` returns None, so the paths are bound to their names
# first and appended afterwards; the names then hold the real directories.
current_dir = os.getcwd()
pixelcnn_dir = os.path.join(current_dir, 'pixelcnn')
sys.path.append(current_dir)
sys.path.append(pixelcnn_dir)
| from pixelcnn.models import GatedPixelCNN | |
| import utils | |
| """ | |
| Hyperparameters | |
| """ | |
import argparse

# Command-line hyperparameters.
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("-save", action="store_true")
parser.add_argument("-gen_samples", action="store_true")
parser.add_argument("--dataset", type=str, default='LATENT_BLOCK')
parser.add_argument("--num_workers", type=int, default=0)
parser.add_argument("--img_dim", type=int, default=64)
parser.add_argument("--input_dim", type=int, default=1,
                    help='1 for grayscale 3 for rgb')
parser.add_argument("--n_embeddings", type=int, default=3,
                    help='number of embeddings from VQ VAE')
parser.add_argument("--n_layers", type=int, default=5)
parser.add_argument("--learning_rate", type=float, default=3e-4)
args = parser.parse_args()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data loaders.
# NOTE(review): the dataset name is hard-coded here; --dataset is only read
# by test() and by the checkpoint file name. Confirm whether the loader
# should receive args.dataset instead.
_, _, train_loader, test_loader, _ = utils.load_data_and_data_loaders(
    'LATENT_BLOCK', args.batch_size)

# Model, loss and optimizer. Everything is placed on ``device`` (rather than
# an unconditional .cuda()) so the script also starts on CPU-only hosts.
model = GatedPixelCNN(args.n_embeddings, args.img_dim**2, args.n_layers).to(device)
criterion = nn.CrossEntropyLoss().to(device)
opt = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
| """ | |
| train, test, and log | |
| """ | |
def train():
    """Run one optimisation pass over ``train_loader``.

    Relies on the module-level ``model``, ``criterion``, ``opt``, ``device``
    and ``args``. For every batch the model's logits are compared against the
    input itself with cross-entropy, and one optimizer step is taken.

    Every ``args.log_interval`` batches a progress line is printed with the
    mean loss over the last ``args.log_interval`` batches and the wall-clock
    time spent on the most recent batch.
    """
    train_loss = []
    n_batches = len(train_loader)
    for batch_idx, (x, label) in enumerate(train_loader):
        start_time = time.time()

        # NOTE(review): test() feeds x[:, 0] to the model while this loop
        # feeds x unchanged -- confirm which input layout the model expects.
        x = x.to(device)
        label = label.to(device)

        # Train PixelCNN with images
        logits = model(x, label)
        # Move dim 1 last so each row of the flattened logits holds the
        # n_embeddings scores for one position.
        logits = logits.permute(0, 2, 3, 1).contiguous()
        loss = criterion(
            logits.view(-1, args.n_embeddings),
            x.view(-1)
        )

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_loss.append(loss.item())

        if (batch_idx + 1) % args.log_interval == 0:
            print('\tIter: [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
                batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / n_batches,
                np.asarray(train_loss)[-args.log_interval:].mean(0),
                time.time() - start_time
            ))
def test():
    """Evaluate the model on ``test_loader`` and return the mean batch loss.

    Relies on the module-level ``model``, ``criterion``, ``device`` and
    ``args``. Gradients are disabled for the whole pass. For the
    'LATENT_BLOCK' dataset channel 0 of each batch is used as-is; for any
    other dataset channel 0 is scaled by ``args.n_embeddings - 1`` and cast
    to long before being fed to the model.

    Returns:
        The mean of the per-batch cross-entropy losses (also printed together
        with the elapsed wall-clock time).
    """
    start_time = time.time()
    val_loss = []
    with torch.no_grad():
        for x, label in test_loader:
            if args.dataset == 'LATENT_BLOCK':
                x = x[:, 0].to(device)
            else:
                x = (x[:, 0] * (args.n_embeddings - 1)).long().to(device)
            label = label.to(device)

            logits = model(x, label)
            # Move dim 1 last so each row of the flattened logits holds the
            # n_embeddings scores for one position.
            logits = logits.permute(0, 2, 3, 1).contiguous()
            loss = criterion(
                logits.view(-1, args.n_embeddings),
                x.view(-1)
            )
            val_loss.append(loss.item())

    mean_loss = np.asarray(val_loss).mean(0)
    print('Validation Completed!\tLoss: {} Time: {}'.format(
        mean_loss,
        time.time() - start_time
    ))
    return mean_loss
def generate_samples(epoch):
    """Generate a batch of class-conditioned samples and print the first one.

    Builds 100 labels (the values 0..9 repeated ten times), moves them to the
    module-level ``device`` and asks ``model.generate`` for one
    ``args.img_dim`` x ``args.img_dim`` sample per label.

    Args:
        epoch: Current epoch number; accepted for the caller's convenience
            but not used by this function.
    """
    label = torch.arange(10).expand(10, 10).contiguous().view(-1)
    label = label.long().to(device)
    x_tilde = model.generate(
        label, shape=(args.img_dim, args.img_dim), batch_size=100)
    print(x_tilde[0])
# Training driver: train, validate, checkpoint on improvement (or always when
# -save is given) and optionally generate samples after every epoch.
# Start from +inf so the first validation loss always counts as an improvement.
BEST_LOSS = float('inf')
LAST_SAVED = -1
# Epochs are numbered from 1; the upper bound is inclusive so that exactly
# args.epochs epochs are run.
for epoch in range(1, args.epochs + 1):
    print("\nEpoch {}:".format(epoch))
    train()
    cur_loss = test()

    if args.save or cur_loss <= BEST_LOSS:
        BEST_LOSS = cur_loss
        LAST_SAVED = epoch

        print("Saving model!")
        # torch.save does not create missing parent directories.
        os.makedirs('results', exist_ok=True)
        torch.save(model.state_dict(), 'results/{}_pixelcnn.pt'.format(args.dataset))
    else:
        print("Not saving model! Last saved: {}".format(LAST_SAVED))
    if args.gen_samples:
        generate_samples(epoch)