from __future__ import print_function from miscc.utils import mkdir_p from miscc.utils import build_super_images from miscc.losses import sent_loss, words_loss from miscc.config import cfg, cfg_from_file from datasets import prepare_data, TextBertDataset from DAMSM import CNN_ENCODER, BERT_RNN_ENCODER from torch.utils.tensorboard import SummaryWriter import os import sys import time import random import pprint import datetime import dateutil.tz import argparse import numpy as np from PIL import Image import torch import torch.nn as nn import torch.optim as optim from torch.autograd import Variable import torch.backends.cudnn as cudnn import torchvision.transforms as transforms dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) sys.path.append(dir_path) UPDATE_INTERVAL = 200 def parse_args(): parser = argparse.ArgumentParser(description='Train a DAMSM network') parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default='cfg/DAMSM/bird.yaml', type=str) parser.add_argument('--gpu', dest='gpu_id', type=int, default=0) parser.add_argument('--data_dir', dest='data_dir', type=str, default='data/birds') parser.add_argument('--manualSeed', type=int, default=0, help='manual seed') args = parser.parse_args() return args def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer, epoch, ixtoword, image_dir, writer): cnn_model.train() rnn_model.train() s_total_loss = 0 w_total_loss = 0 s_total_loss0 = 0 s_total_loss1 = 0 w_total_loss0 = 0 w_total_loss1 = 0 count = (epoch + 1) * len(dataloader) start_time = time.time() for step, data in enumerate(dataloader, 0): # print('step', step) rnn_model.zero_grad() cnn_model.zero_grad() imgs, captions, cap_lens, \ class_ids, keys = prepare_data(data) # words_features: batch_size x nef x 17 x 17 # sent_code: batch_size x nef words_features, sent_code = cnn_model(imgs[-1]) # --> batch_size x nef x 17*17 nef, att_sze = words_features.size(1), words_features.size(2) # words_features = words_features.view(batch_size, nef, -1) hidden = rnn_model.init_hidden(batch_size) # words_emb: batch_size x nef x seq_len # sent_emb: batch_size x nef words_emb, sent_emb = rnn_model(captions, cap_lens, hidden) w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels, cap_lens, class_ids, batch_size) w_total_loss0 += w_loss0.data w_total_loss1 += w_loss1.data loss = w_loss0 + w_loss1 w_total_loss += (w_loss0 + w_loss1).data s_loss0, s_loss1 = \ sent_loss(sent_code, sent_emb, labels, class_ids, batch_size) loss += s_loss0 + s_loss1 s_total_loss0 += s_loss0.data s_total_loss1 += s_loss1.data s_total_loss += (s_loss0 + s_loss1).data # loss.backward() # # `clip_grad_norm` helps prevent # the exploding gradient problem in RNNs / LSTMs. torch.nn.utils.clip_grad_norm_(rnn_model.parameters(), cfg.TRAIN.RNN_GRAD_CLIP) optimizer.step() print('| epoch {:3d} | {:5d}/{:5d} batches |'.format(epoch, step, len(dataloader))) if step % UPDATE_INTERVAL == 0: count = epoch * len(dataloader) + step s_cur_loss0 = s_total_loss0.item() / UPDATE_INTERVAL s_cur_loss1 = s_total_loss1.item() / UPDATE_INTERVAL w_cur_loss0 = w_total_loss0.item() / UPDATE_INTERVAL w_cur_loss1 = w_total_loss1.item() / UPDATE_INTERVAL elapsed = time.time() - start_time print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | ' 's_loss {:5.2f} {:5.2f} | ' 'w_loss {:5.2f} {:5.2f}' .format(epoch, step, len(dataloader), elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0, s_cur_loss1, w_cur_loss0, w_cur_loss1)) s_total_loss0 = 0 s_total_loss1 = 0 w_total_loss0 = 0 w_total_loss1 = 0 start_time = time.time() # attention Maps img_set, _ = \ build_super_images(imgs[-1].cpu(), captions, ixtoword, attn_maps, att_sze) if img_set is not None: im = Image.fromarray(img_set) fullpath = '%s/attention_maps%d.png' % (image_dir, step) im.save(fullpath) if len(dataloader) == 0: return count s_total_loss = s_total_loss.item() / step w_total_loss = w_total_loss.item() / step writer.add_scalar('s_loss/train', s_total_loss, epoch) writer.add_scalar('w_loss/train', w_total_loss, epoch) return count def evaluate(dataloader, cnn_model, rnn_model, batch_size): cnn_model.eval() rnn_model.eval() s_total_loss = 0 w_total_loss = 0 for step, data in enumerate(dataloader, 0): real_imgs, captions, cap_lens, \ class_ids, keys = prepare_data(data) words_features, sent_code = cnn_model(real_imgs[-1]) # nef = words_features.size(1) # words_features = words_features.view(batch_size, nef, -1) hidden = rnn_model.init_hidden(batch_size) words_emb, sent_emb = rnn_model(captions, cap_lens, hidden) w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels, cap_lens, class_ids, batch_size) w_total_loss += (w_loss0 + w_loss1).data s_loss0, s_loss1 = \ sent_loss(sent_code, sent_emb, labels, class_ids, batch_size) s_total_loss += (s_loss0 + s_loss1).data print('{:5d}/{:5d} batches |'.format(step, len(dataloader))) # if step == 50: # break s_cur_loss = s_total_loss.item() / step w_cur_loss = w_total_loss.item() / step return s_cur_loss, w_cur_loss def build_models(): # build model ############################################################ text_encoder = BERT_RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM) labels = Variable(torch.LongTensor(range(batch_size))) start_epoch = 0 if cfg.TRAIN.NET_E != '': state_dict = torch.load(cfg.TRAIN.NET_E) text_encoder.load_state_dict(state_dict) print('Load ', cfg.TRAIN.NET_E) # name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder') state_dict = torch.load(name) image_encoder.load_state_dict(state_dict) print('Load ', name) istart = cfg.TRAIN.NET_E.rfind('_') + 8 iend = cfg.TRAIN.NET_E.rfind('.') start_epoch = cfg.TRAIN.NET_E[istart:iend] start_epoch = int(start_epoch) + 1 print('start_epoch', start_epoch) if cfg.CUDA: text_encoder = text_encoder.cuda() image_encoder = image_encoder.cuda() labels = labels.cuda() return text_encoder, image_encoder, labels, start_epoch if __name__ == "__main__": args = parse_args() if args.cfg_file is not None: cfg_from_file(args.cfg_file) if args.gpu_id == -1: cfg.CUDA = False else: cfg.GPU_ID = args.gpu_id if args.data_dir != '': cfg.DATA_DIR = args.data_dir print('Using config:') pprint.pprint(cfg) if not cfg.TRAIN.FLAG: args.manualSeed = 100 elif args.manualSeed is None: args.manualSeed = random.randint(1, 10000) random.seed(args.manualSeed) np.random.seed(args.manualSeed) torch.manual_seed(args.manualSeed) if cfg.CUDA: torch.cuda.manual_seed_all(args.manualSeed) ########################################################################## output_dir = 'output/%s/%s' % \ (cfg.DATASET_NAME, cfg.CONFIG_NAME) model_dir = os.path.join(output_dir, 'Models') image_dir = os.path.join(output_dir, 'Image') mkdir_p(model_dir) mkdir_p(image_dir) torch.cuda.set_device(cfg.GPU_ID) cudnn.benchmark = True # Get data loader ################################################## imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1)) batch_size = cfg.TRAIN.BATCH_SIZE image_transform = transforms.Compose([ transforms.Scale(int(imsize * 76 / 64)), transforms.RandomCrop(imsize), transforms.RandomHorizontalFlip()]) dataset = TextBertDataset(cfg.DATA_DIR, 'train', base_size=cfg.TREE.BASE_SIZE, transform=image_transform) print(dataset.n_words, dataset.embeddings_num) assert dataset dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) # validation data # dataset_val = TextBertDataset(cfg.DATA_DIR, 'test', base_size=cfg.TREE.BASE_SIZE, transform=image_transform) dataloader_val = torch.utils.data.DataLoader( dataset_val, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) # Train ############################################################## text_encoder, image_encoder, labels, start_epoch = build_models() lr = cfg.TRAIN.ENCODER_LR if os.path.exists('%s/checkpoint.pth' % (model_dir)): checkpoint = torch.load('%s/checkpoint.pth' % (model_dir)) text_encoder.load_state_dict(torch.load('%s/text_encoder.pth' % (model_dir))) image_encoder.load_state_dict(torch.load('%s/image_encoder.pth' % (model_dir))) start_epoch = checkpoint['epoch'] lr = checkpoint['lr'] text_encoder.train() image_encoder.train() print("Loading last checkpoint at epoch: ",start_epoch) else: print("No checkpoint to load") writer = SummaryWriter('tensorboards/%s/%s' % (cfg.DATASET_NAME, cfg.CONFIG_NAME)) para = list(text_encoder.parameters()) for v in image_encoder.parameters(): if v.requires_grad: para.append(v) # optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999)) # At any point you can hit Ctrl + C to break out of training early. try: # lr = cfg.TRAIN.ENCODER_LR for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH): optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999)) epoch_start_time = time.time() count = train(dataloader, image_encoder, text_encoder, batch_size, labels, optimizer, epoch, dataset.ixtoword, image_dir, writer) print('-' * 89) if len(dataloader_val) > 0 and epoch % 10 == 0: s_loss, w_loss = evaluate(dataloader_val, image_encoder, text_encoder, batch_size) print('| end epoch {:3d} | valid loss ' '{:5.2f} {:5.2f} | lr {:.5f}|' .format(epoch, s_loss, w_loss, lr)) writer.add_scalar('s_loss/val', s_loss, epoch) writer.add_scalar('w_loss/val', w_loss, epoch) print('-' * 89) if lr > cfg.TRAIN.ENCODER_LR/10.: lr *= 0.98 if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or epoch == cfg.TRAIN.MAX_EPOCH): torch.save(image_encoder.state_dict(), '%s/image_encoder.pth' % (model_dir)) torch.save(text_encoder.state_dict(), '%s/text_encoder.pth' % (model_dir)) torch.save({'epoch' : epoch, 'lr' : lr}, '%s/checkpoint.pth' % (model_dir)) print('Save G/Ds models.') except KeyboardInterrupt: print('-' * 89) print('Exiting from training early')