Instructions to use SrinivasMudiraj/Baaz with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use SrinivasMudiraj/Baaz with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("SrinivasMudiraj/Baaz", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import print_function | |
| from miscc.utils import mkdir_p | |
| from miscc.utils import build_super_images | |
| from miscc.losses import sent_loss, words_loss | |
| from miscc.config import cfg, cfg_from_file | |
| from datasets import prepare_data, TextBertDataset | |
| from DAMSM import CNN_ENCODER, BERT_RNN_ENCODER | |
| from torch.utils.tensorboard import SummaryWriter | |
| import os | |
| import sys | |
| import time | |
| import random | |
| import pprint | |
| import datetime | |
| import dateutil.tz | |
| import argparse | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.autograd import Variable | |
| import torch.backends.cudnn as cudnn | |
| import torchvision.transforms as transforms | |
# Directory that contains this script. The earlier expression joined './.'
# onto the file path itself, which normalises straight back to the file path,
# so a file (not a directory) ended up on sys.path.
# NOTE(review): this runs after the project imports at the top of the file,
# so it cannot influence how those imports are resolved.
dir_path = os.path.dirname(os.path.realpath(__file__))
if dir_path not in sys.path:
    sys.path.append(dir_path)

# Number of training steps between progress reports / attention-map dumps.
UPDATE_INTERVAL = 200
def parse_args():
    """Parse the command-line options of the DAMSM training script.

    Options are read from ``sys.argv``.

    Returns:
        argparse.Namespace: carries ``cfg_file``, ``gpu_id``, ``data_dir``
        and ``manualSeed``.
    """
    cli = argparse.ArgumentParser(description='Train a DAMSM network')

    # (flags, keyword options) for every supported command-line switch.
    option_specs = (
        (('--cfg',), {'dest': 'cfg_file',
                      'help': 'optional config file',
                      'default': 'cfg/DAMSM/bird.yaml',
                      'type': str}),
        (('--gpu',), {'dest': 'gpu_id', 'type': int, 'default': 0}),
        (('--data_dir',), {'dest': 'data_dir', 'type': str,
                           'default': 'data/birds'}),
        (('--manualSeed',), {'type': int, 'default': 0,
                             'help': 'manual seed'}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
def train(dataloader, cnn_model, rnn_model, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir, writer):
    """Train the image and text encoders for one epoch.

    For every batch the word-level and sentence-level losses returned by
    ``words_loss`` and ``sent_loss`` are summed and back-propagated, the
    gradients of ``rnn_model`` are norm-clipped with
    ``cfg.TRAIN.RNN_GRAD_CLIP`` and ``optimizer`` takes one step. Every
    ``UPDATE_INTERVAL`` steps the windowed losses are printed and, when
    ``build_super_images`` returns an image, it is saved to
    ``<image_dir>/attention_maps<step>.png``. After the last batch the mean
    sentence and word losses of the epoch are logged to ``writer``.

    Args:
        dataloader: iterable of batches accepted by ``prepare_data``.
        cnn_model: image encoder; called with ``imgs[-1]``.
        rnn_model: text encoder; must provide ``init_hidden(batch_size)``.
        batch_size: batch size forwarded to the encoders and loss functions.
        labels: match labels forwarded to the loss functions.
        optimizer: optimizer stepped once per batch.
        epoch: index of the current epoch (used for logging and ``count``).
        ixtoword: index-to-word mapping forwarded to ``build_super_images``.
        image_dir: directory that receives the attention-map images.
        writer: object with ``add_scalar(tag, value, step)``.

    Returns:
        ``epoch * len(dataloader) + step`` of the most recent progress
        report, or ``(epoch + 1) * len(dataloader)`` when no batch was run.
    """
    cnn_model.train()
    rnn_model.train()

    num_batches = len(dataloader)
    count = (epoch + 1) * num_batches
    if num_batches == 0:
        # Nothing to train on, hence nothing to average or log.
        return count

    # Epoch-wide totals (sentence / word losses).
    s_total_loss = 0
    w_total_loss = 0
    # Windowed totals, reset after every progress report.
    s_total_loss0 = 0
    s_total_loss1 = 0
    w_total_loss0 = 0
    w_total_loss1 = 0

    batches_seen = 0
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        att_sze = words_features.size(2)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb,
                                                 labels, cap_lens,
                                                 class_ids, batch_size)
        w_total_loss0 += w_loss0.data
        w_total_loss1 += w_loss1.data
        loss = w_loss0 + w_loss1
        w_total_loss += (w_loss0 + w_loss1).data

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += s_loss0.data
        s_total_loss1 += s_loss1.data
        s_total_loss += (s_loss0 + s_loss1).data

        loss.backward()
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(rnn_model.parameters(),
                                       cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()
        batches_seen = step + 1

        print('| epoch {:3d} | {:5d}/{:5d} batches |'.format(epoch, step, num_batches))

        if step % UPDATE_INTERVAL == 0:
            count = epoch * num_batches + step

            # The report at step 0 covers a single batch; every later report
            # covers exactly UPDATE_INTERVAL batches since the last reset.
            window = 1 if step == 0 else UPDATE_INTERVAL

            s_cur_loss0 = s_total_loss0.item() / window
            s_cur_loss1 = s_total_loss1.item() / window
            w_cur_loss0 = w_total_loss0.item() / window
            w_cur_loss1 = w_total_loss1.item() / window

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, num_batches,
                          elapsed * 1000. / window,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0
            s_total_loss1 = 0
            w_total_loss0 = 0
            w_total_loss1 = 0
            start_time = time.time()

            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)

    if batches_seen == 0:
        return count

    # Average over the number of batches processed. Dividing by the last
    # step index would be off by one and would divide by zero for a
    # single-batch loader.
    s_total_loss = s_total_loss.item() / batches_seen
    w_total_loss = w_total_loss.item() / batches_seen
    writer.add_scalar('s_loss/train', s_total_loss, epoch)
    writer.add_scalar('w_loss/train', w_total_loss, epoch)
    return count
def evaluate(dataloader, cnn_model, rnn_model, batch_size):
    """Compute the mean sentence and word losses over a validation loader.

    Both encoders are switched to eval mode and run without gradient
    tracking. Note that this function reads the module-level name ``labels``
    (assigned in the ``__main__`` section from ``build_models()``); it is not
    passed in as a parameter.

    Args:
        dataloader: iterable of batches accepted by ``prepare_data``.
        cnn_model: image encoder; called with the last entry of the images.
        rnn_model: text encoder; must provide ``init_hidden(batch_size)``.
        batch_size: batch size forwarded to the encoders and loss functions.

    Returns:
        Tuple ``(s_cur_loss, w_cur_loss)``: mean sentence loss and mean word
        loss per batch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    cnn_model.eval()
    rnn_model.eval()
    s_total_loss = 0
    w_total_loss = 0
    batches_seen = 0

    # No parameters are updated here, so skip building autograd graphs.
    with torch.no_grad():
        for step, data in enumerate(dataloader, 0):
            real_imgs, captions, cap_lens, \
                class_ids, keys = prepare_data(data)

            words_features, sent_code = cnn_model(real_imgs[-1])

            hidden = rnn_model.init_hidden(batch_size)
            words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

            w_loss0, w_loss1, attn = words_loss(words_features, words_emb,
                                                labels, cap_lens,
                                                class_ids, batch_size)
            w_total_loss += (w_loss0 + w_loss1).data

            s_loss0, s_loss1 = \
                sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
            s_total_loss += (s_loss0 + s_loss1).data

            batches_seen = step + 1
            print('{:5d}/{:5d} batches |'.format(step, len(dataloader)))

    if batches_seen == 0:
        raise ValueError('evaluate() received an empty dataloader')

    # Average over the number of batches processed. Dividing by the last
    # step index would be off by one and would divide by zero for a
    # single-batch loader.
    s_cur_loss = s_total_loss.item() / batches_seen
    w_cur_loss = w_total_loss.item() / batches_seen
    return s_cur_loss, w_cur_loss
def build_models():
    """Create the text and image encoders and the batch match labels.

    Reads the module-level names ``dataset`` and ``batch_size`` (assigned in
    the ``__main__`` section) as well as ``cfg``. When ``cfg.TRAIN.NET_E`` is
    non-empty, the text encoder weights are loaded from that path and the
    image encoder weights from the same path with ``'text_encoder'`` replaced
    by ``'image_encoder'``.

    Returns:
        Tuple ``(text_encoder, image_encoder, labels, start_epoch)``.
        ``labels`` is a LongTensor holding ``0 .. batch_size - 1``.
        ``start_epoch`` is one more than the epoch number parsed from the
        ``cfg.TRAIN.NET_E`` file name, or 0 when no weights are loaded or the
        name carries no epoch number.
    """
    # build model ############################################################
    text_encoder = BERT_RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0

    if cfg.TRAIN.NET_E != '':
        # Load onto the CPU first so weights saved from a GPU run can also be
        # restored when CUDA is disabled; the models are moved to the GPU
        # below when cfg.CUDA is set.
        state_dict = torch.load(cfg.TRAIN.NET_E, map_location='cpu')
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name, map_location='cpu')
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        # The epoch number is expected between the 8 characters that start at
        # the last '_' (i.e. '_encoder') and the file extension, e.g.
        # 'text_encoder200.pth' -> '200'.
        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        epoch_text = cfg.TRAIN.NET_E[istart:iend]
        try:
            start_epoch = int(epoch_text) + 1
        except ValueError:
            # Names without an epoch number (e.g. 'text_encoder.pth') must
            # not abort the run.
            print('Could not parse an epoch number from', cfg.TRAIN.NET_E,
                  '- starting at epoch 0')
            start_epoch = 0
        print('start_epoch', start_epoch)

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch
if __name__ == "__main__":
    # Script entry point: apply the configuration, build the data loaders and
    # encoders, optionally resume from a checkpoint, then train.
    # NOTE: `dataset`, `batch_size` and `labels` are deliberately left at
    # module scope because build_models() and evaluate() read them as globals.
    args = parse_args()
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    if args.gpu_id == -1:
        cfg.CUDA = False
    else:
        cfg.GPU_ID = args.gpu_id

    if args.data_dir != '':
        cfg.DATA_DIR = args.data_dir
    print('Using config:')
    pprint.pprint(cfg)

    if not cfg.TRAIN.FLAG:
        args.manualSeed = 100
    elif args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    random.seed(args.manualSeed)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)

    ##########################################################################
    output_dir = 'output/%s/%s' % \
        (cfg.DATASET_NAME, cfg.CONFIG_NAME)

    model_dir = os.path.join(output_dir, 'Models')
    image_dir = os.path.join(output_dir, 'Image')
    mkdir_p(model_dir)
    mkdir_p(image_dir)

    # Only select a CUDA device when CUDA is enabled; `--gpu -1` disables it.
    if cfg.CUDA:
        torch.cuda.set_device(cfg.GPU_ID)
    cudnn.benchmark = True

    # Get data loader ##################################################
    imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1))
    batch_size = cfg.TRAIN.BATCH_SIZE
    image_transform = transforms.Compose([
        # `transforms.Scale` was removed from torchvision; `Resize` is its
        # replacement.
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
        transforms.RandomHorizontalFlip()])
    dataset = TextBertDataset(cfg.DATA_DIR, 'train',
                              base_size=cfg.TREE.BASE_SIZE,
                              transform=image_transform)

    print(dataset.n_words, dataset.embeddings_num)
    assert dataset
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, drop_last=True,
        shuffle=True, num_workers=int(cfg.WORKERS))

    # validation data #
    dataset_val = TextBertDataset(cfg.DATA_DIR, 'test',
                                  base_size=cfg.TREE.BASE_SIZE,
                                  transform=image_transform)
    dataloader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=batch_size, drop_last=True,
        shuffle=True, num_workers=int(cfg.WORKERS))

    # Train ##############################################################
    text_encoder, image_encoder, labels, start_epoch = build_models()

    lr = cfg.TRAIN.ENCODER_LR

    checkpoint_path = '%s/checkpoint.pth' % (model_dir)
    if os.path.exists(checkpoint_path):
        # Load onto the CPU; load_state_dict copies the values into the
        # already-placed model parameters.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        text_encoder.load_state_dict(
            torch.load('%s/text_encoder.pth' % (model_dir), map_location='cpu'))
        image_encoder.load_state_dict(
            torch.load('%s/image_encoder.pth' % (model_dir), map_location='cpu'))
        # The checkpoint is written after its epoch has finished (with the
        # already-decayed lr), so training resumes with the following epoch.
        start_epoch = checkpoint['epoch'] + 1
        lr = checkpoint['lr']
        text_encoder.train()
        image_encoder.train()
        print("Loading last checkpoint at epoch: ", checkpoint['epoch'])
    else:
        print("No checkpoint to load")

    writer = SummaryWriter('tensorboards/%s/%s' % (cfg.DATASET_NAME, cfg.CONFIG_NAME))

    # Optimise every text-encoder parameter plus the trainable image-encoder
    # parameters.
    para = list(text_encoder.parameters())
    para.extend(v for v in image_encoder.parameters() if v.requires_grad)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH):
            # A fresh optimizer is built each epoch so the decayed lr takes
            # effect (this also starts from empty optimizer state).
            optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999))
            epoch_start_time = time.time()
            count = train(dataloader, image_encoder, text_encoder,
                          batch_size, labels, optimizer, epoch,
                          dataset.ixtoword, image_dir, writer)
            print('-' * 89)
            if len(dataloader_val) > 0 and epoch % 10 == 0:
                s_loss, w_loss = evaluate(dataloader_val, image_encoder,
                                          text_encoder, batch_size)
                print('| end epoch {:3d} | valid loss '
                      '{:5.2f} {:5.2f} | lr {:.5f}|'
                      .format(epoch, s_loss, w_loss, lr))
                writer.add_scalar('s_loss/val', s_loss, epoch)
                writer.add_scalar('w_loss/val', w_loss, epoch)
                print('-' * 89)
            if lr > cfg.TRAIN.ENCODER_LR/10.:
                lr *= 0.98

            # range() stops at MAX_EPOCH - 1, so that is the final epoch.
            is_last_epoch = epoch == cfg.TRAIN.MAX_EPOCH - 1
            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or is_last_epoch:
                torch.save(image_encoder.state_dict(),
                           '%s/image_encoder.pth' % (model_dir))
                torch.save(text_encoder.state_dict(),
                           '%s/text_encoder.pth' % (model_dir))
                torch.save({'epoch' : epoch, 'lr' : lr},
                           '%s/checkpoint.pth' % (model_dir))
                print('Save G/Ds models.')
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')
    finally:
        # Flush pending TensorBoard events even when training is interrupted.
        writer.close()