Baaz / code /main.py
SrinivasMudiraj's picture
Upload 34 files
9c7aa86 verified
from __future__ import print_function
from miscc.utils import mkdir_p
from miscc.config import cfg, cfg_from_file
from datasets import prepare_data, TextBertDataset
from eval.IS.bird.inception_score_bird import compute_IS
from eval.FID.fid_score import compute_FID
from DAMSM import BERT_RNN_ENCODER
from transformers import AutoTokenizer, AutoModel
import os
import sys
import time
import random
import pprint
import datetime
import dateutil.tz
import argparse
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from model import NetG,NetD
from torchvision.models import inception_v3
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.')))
sys.path.append(dir_path)
import multiprocessing
multiprocessing.set_start_method('spawn', True)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # or "n"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
UPDATE_INTERVAL = 200
def parse_args():
parser = argparse.ArgumentParser(description='Train a DAMSM network')
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',
default='cfg/bird.yml', type=str)
parser.add_argument('--gpu', dest='gpu_id', type=int, default=0)
parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--evaluation', type=int, help='evaluation', default= 0)
args = parser.parse_args()
return args
def sampling(text_encoder, netG, dataloader,device, validation= False):
state_epoch = 0
model_dir = '../models/%s/checkpoint_nets.pth' % (cfg.CONFIG_NAME)
if(not validation and os.path.exists(model_dir)):
checkpoint = torch.load(model_dir)
netG.load_state_dict(checkpoint['netG_state'])
state_epoch = checkpoint['epoch']
netG.eval()
print("loading last checkpoint at epoch: ",state_epoch)
batch_size = cfg.TRAIN.BATCH_SIZE
save_dir = '../images/%s/test' % (cfg.CONFIG_NAME)
mkdir_p(save_dir)
cnt = 0
for i in range(1): # (cfg.TEXT.CAPTIONS_PER_IMAGE):
for step, data in enumerate(dataloader, 0):
imags, captions, cap_lens, class_ids, keys = prepare_data(data)
cnt += batch_size
hidden = text_encoder.init_hidden(batch_size)
# words_embs: batch_size x nef x seq_len
# sent_emb: batch_size x nef
words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
#######################################################
# (2) Generate fake images
######################################################
with torch.no_grad():
noise = torch.randn(batch_size, 100)
noise=noise.to(device)
fake_imgs = netG(noise,sent_emb)
for j in range(batch_size):
s_tmp = '%s/%s' % (save_dir, keys[j])
folder = s_tmp[:s_tmp.rfind('/')]
if not os.path.isdir(folder):
print('Make a new folder: ', folder)
mkdir_p(folder)
im = fake_imgs[j].data.cpu().numpy()
# [-1, 1] --> [0, 255]
im = (im + 1.0) * 127.5
im = im.astype(np.uint8)
im = np.transpose(im, (1, 2, 0))
im = Image.fromarray(im)
fullpath = '%s_%3d.png' % (s_tmp,i)
im.save(fullpath)
return state_epoch
def validate(text_encoder, netG,device, writer, epoch):
dataset = TextBertDataset(cfg.DATA_DIR, 'test',
base_size=cfg.TREE.BASE_SIZE,
transform=image_transform)
print(dataset.n_words, dataset.embeddings_num)
assert dataset
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, drop_last=True,
shuffle=True, num_workers=int(cfg.WORKERS))
print(f'Starting generate validation images ... at {epoch}')
sampling(text_encoder, netG, dataloader, device, validation= True)
netG.train()
print(f'Starting compute FID & IS ... at {epoch}')
compute_FID(['/home/icmr/Srinivas/PhD/Text to Image/Hindi/Arabic-text-visualization-using-ADF-GAN/data/CUB-200/CUB-200_val.npz',
'../images/%s/test' % (cfg.CONFIG_NAME)], writer, epoch)
mean, std = compute_IS('../images/%s/test' % (cfg.CONFIG_NAME), writer, epoch)
final_score = mean / std
print(f"Final Inception Score: {final_score:.4f}")
#########################################
########################################
def train(dataloader,netG,netD,text_encoder,optimizerG,optimizerD,state_epoch,batch_size,device, writer):
path = '../models/%s/checkpoint_nets.pth' % (cfg.CONFIG_NAME)
if(os.path.exists(path)):
checkpoint = torch.load(path)
netG.load_state_dict(checkpoint['netG_state'])
netD.load_state_dict(checkpoint['netD_state'])
optimizerG.load_state_dict(checkpoint['optimizerG_state'])
optimizerD.load_state_dict(checkpoint['optimizerD_state'])
state_epoch = checkpoint['epoch']
netG.train()
netD.train()
print("Loading last checkpoint at epoch: ",state_epoch)
else:
print("No checkpoint to load")
for epoch in range(state_epoch+1, cfg.TRAIN.MAX_EPOCH+1):
D_loss = 0.0
G_loss = 0.0
for step, data in enumerate(dataloader, 0):
imags, captions, cap_lens, class_ids, keys = prepare_data(data)
hidden = text_encoder.init_hidden(batch_size)
# words_embs: batch_size x nef x seq_len
# sent_emb: batch_size x nef
words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
imgs=imags[0].to(device)
real_features = netD(imgs)
output = netD.COND_DNET(real_features,sent_emb)
errD_real = torch.nn.ReLU()(1.0 - output).mean()
output = netD.COND_DNET(real_features[:(batch_size - 1)], sent_emb[1:batch_size])
errD_mismatch = torch.nn.ReLU()(1.0 + output).mean()
# synthesize fake images
noise = torch.randn(batch_size, 100)
noise=noise.to(device)
fake = netG(noise,sent_emb)
# G does not need update with D
fake_features = netD(fake.detach())
errD_fake = netD.COND_DNET(fake_features,sent_emb)
errD_fake = torch.nn.ReLU()(1.0 + errD_fake).mean()
errD = errD_real + (errD_fake + errD_mismatch)/2.0
optimizerD.zero_grad()
optimizerG.zero_grad()
errD.backward()
optimizerD.step()
#MA-GP
interpolated = (imgs.data).requires_grad_()
sent_inter = (sent_emb.data).requires_grad_()
features = netD(interpolated)
out = netD.COND_DNET(features,sent_inter)
grads = torch.autograd.grad(outputs=out,
inputs=(interpolated,sent_inter),
grad_outputs=torch.ones(out.size()).cuda(),
retain_graph=True,
create_graph=True,
only_inputs=True)
grad0 = grads[0].view(grads[0].size(0), -1)
grad1 = grads[1].view(grads[1].size(0), -1)
grad = torch.cat((grad0,grad1),dim=1)
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
d_loss_gp = torch.mean((grad_l2norm) ** 6)
d_loss = 2.0 * d_loss_gp
optimizerD.zero_grad()
optimizerG.zero_grad()
d_loss.backward()
optimizerD.step()
# update G
features = netD(fake)
output = netD.COND_DNET(features,sent_emb)
errG = - output.mean()
optimizerG.zero_grad()
optimizerD.zero_grad()
errG.backward()
optimizerG.step()
D_loss += errD.item() + d_loss.item()
G_loss += errG.item()
print('[%d/%d][%d/%d] Loss_D: %.3f Loss_G %.3f total_Loss_D: %.3f total_Loss_G %.3f'
% (epoch, cfg.TRAIN.MAX_EPOCH, step, len(dataloader), errD.item(), errG.item(), D_loss, G_loss))
vutils.save_image(fake.data,
'../images/%s/fakes/fake_samples_epoch_%03d.png' % (cfg.CONFIG_NAME, epoch),
normalize=True)
# if epoch%10==0:
torch.save({
'epoch': epoch,
'netG_state': netG.state_dict(),
'optimizerG_state': optimizerG.state_dict(),
'netD_state': netD.state_dict(),
'optimizerD_state': optimizerD.state_dict()
}, path)
writer.add_scalar('D_Loss/train', D_loss, epoch)
writer.add_scalar('G_Loss/train', G_loss, epoch)
if epoch%50 == 0:
return epoch
return cfg.TRAIN.MAX_EPOCH
if __name__ == "__main__":
args = parse_args()
if args.cfg_file is not None:
cfg_from_file(args.cfg_file)
if args.gpu_id == -1:
cfg.CUDA = False
else:
cfg.GPU_ID = args.gpu_id
if args.data_dir != '':
cfg.DATA_DIR = args.data_dir
cfg.B_VALIDATION = bool(args.evaluation)
print('Using config:')
pprint.pprint(cfg)
if not cfg.TRAIN.FLAG:
args.manualSeed = 100
elif args.manualSeed is None:
args.manualSeed = 100
#args.manualSeed = random.randint(1, 10000)
print("seed now is : ",args.manualSeed)
random.seed(args.manualSeed)
np.random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if cfg.CUDA:
torch.cuda.manual_seed_all(args.manualSeed)
##########################################################################
torch.cuda.set_device(cfg.GPU_ID)
cudnn.benchmark = True
# Get data loader ##################################################
imsize = cfg.TREE.BASE_SIZE
batch_size = cfg.TRAIN.BATCH_SIZE
image_transform = transforms.Compose([
transforms.Resize(int(imsize * 76 / 64)),
transforms.RandomCrop(imsize),
transforms.RandomHorizontalFlip()])
if cfg.B_VALIDATION:
dataset = TextBertDataset(cfg.DATA_DIR, 'test',
base_size=cfg.TREE.BASE_SIZE,
transform=image_transform)
print(dataset.n_words, dataset.embeddings_num)
assert dataset
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, drop_last=True,
shuffle=True, num_workers=int(cfg.WORKERS))
else:
dataset = TextBertDataset(cfg.DATA_DIR, 'train',
base_size=cfg.TREE.BASE_SIZE,
transform=image_transform)
print(dataset.n_words, dataset.embeddings_num)
assert dataset
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, drop_last=True,
shuffle=True, num_workers=int(cfg.WORKERS))
# # validation data #
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = NetG(cfg.TRAIN.NF, 100).to(device)
netD = NetD(cfg.TRAIN.NF).to(device)
text_encoder = BERT_RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
state_dict = torch.load(cfg.TEXT.DAMSM_NAME, map_location=lambda storage, loc: storage)
state_dict.pop('encoder.embeddings.word_embeddings.weight', None)
text_encoder.load_state_dict(state_dict, strict=False)
text_encoder.cuda()
for p in text_encoder.parameters():
p.requires_grad = False
text_encoder.eval()
state_epoch=0
optimizerG = torch.optim.Adam(netG.parameters(), lr=0.0001, betas=(0.0, 0.9))
optimizerD = torch.optim.Adam(netD.parameters(), lr=0.0004, betas=(0.0, 0.9))
if cfg.B_VALIDATION:
state_epoch = sampling(text_encoder, netG, dataloader,device) # generate images for the whole valid dataset
print('state_epoch: %d'%(state_epoch))
else:
writer = SummaryWriter(f"tensorboards/{cfg.CONFIG_NAME}/ADGAN_train")
epoch = train(dataloader,netG,netD,text_encoder,optimizerG,optimizerD, state_epoch,batch_size,device, writer)
validate(text_encoder, netG, device, writer, epoch)