Instructions to use SrinivasMudiraj/Baaz with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use SrinivasMudiraj/Baaz with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("SrinivasMudiraj/Baaz", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import print_function | |
| from miscc.utils import mkdir_p | |
| from miscc.config import cfg, cfg_from_file | |
| from datasets import prepare_data, TextBertDataset | |
| from eval.IS.bird.inception_score_bird import compute_IS | |
| from eval.FID.fid_score import compute_FID | |
| from DAMSM import BERT_RNN_ENCODER | |
| from transformers import AutoTokenizer, AutoModel | |
| import os | |
| import sys | |
| import time | |
| import random | |
| import pprint | |
| import datetime | |
| import dateutil.tz | |
| import argparse | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.autograd import Variable | |
| import torch.backends.cudnn as cudnn | |
| import torchvision.transforms as transforms | |
| from model import NetG,NetD | |
| from torchvision.models import inception_v3 | |
| import torchvision.utils as vutils | |
| from torch.utils.tensorboard import SummaryWriter | |
# Make the directory containing this script importable, so the local packages
# (miscc, datasets, eval, DAMSM, model) resolve regardless of the current
# working directory. The previous expression joined the *file* path with './.',
# which normalizes back to the file itself and therefore added a useless
# non-directory entry to sys.path.
dir_path = os.path.dirname(os.path.abspath(os.path.realpath(__file__)))
if dir_path not in sys.path:
    sys.path.append(dir_path)

import multiprocessing

# Force the 'spawn' start method for worker processes (second argument is
# force=True). Presumably chosen so CUDA can be used from DataLoader
# workers -- TODO confirm.
multiprocessing.set_start_method('spawn', True)

# Restrict the process to one GPU by default, but keep an explicitly exported
# CUDA_VISIBLE_DEVICES instead of silently overwriting it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")  # or "n"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): not referenced anywhere in the visible code; verify before removing.
UPDATE_INTERVAL = 200
def parse_args(argv=None):
    """Parse the command-line options of the text-to-image GAN script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests or programmatic calls).

    Returns:
        argparse.Namespace with ``cfg_file``, ``gpu_id``, ``data_dir``,
        ``manualSeed`` and ``evaluation``.
    """
    # The script trains / samples the NetG/NetD GAN; the DAMSM text encoder is
    # only loaded pretrained, so the old "Train a DAMSM network" was misleading.
    parser = argparse.ArgumentParser(description='Train a text-to-image GAN')
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default='cfg/bird.yml', type=str)
    # -1 sets cfg.CUDA = False in __main__; any other value becomes cfg.GPU_ID.
    parser.add_argument('--gpu', dest='gpu_id', type=int, default=0)
    # Empty string keeps cfg.DATA_DIR from the config file.
    parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
    parser.add_argument('--manualSeed', type=int, help='manual seed')
    # Non-zero -> cfg.B_VALIDATION: only generate images from the checkpoint.
    parser.add_argument('--evaluation', type=int, help='evaluation', default=0)
    return parser.parse_args(argv)
def sampling(text_encoder, netG, dataloader, device, validation=False):
    """Generate one image per caption in ``dataloader`` and save it as a PNG.

    When ``validation`` is False and ``../models/<CONFIG_NAME>/checkpoint_nets.pth``
    exists, the generator weights are first restored from that checkpoint.
    The generator is always switched to eval mode before generating; callers
    that continue training (e.g. ``validate``) must call ``netG.train()``
    afterwards.

    Images are written to ``../images/<CONFIG_NAME>/test/<key>_<i>.png``, where
    ``<key>`` comes from the batch and ``<i>`` is the caption-round index.

    Args:
        text_encoder: encoder with ``init_hidden(batch_size)`` returning a
            hidden state, called as ``text_encoder(captions, cap_lens, hidden)``.
        netG: generator called as ``netG(noise, sent_emb)``.
        dataloader: yields batches consumed by ``prepare_data``.
        device: device the noise vectors are moved to.
        validation: if True, skip loading the checkpoint and use ``netG`` as is.

    Returns:
        The epoch stored in the loaded checkpoint, or 0 if none was loaded.
    """
    state_epoch = 0
    model_dir = '../models/%s/checkpoint_nets.pth' % (cfg.CONFIG_NAME)
    if not validation and os.path.exists(model_dir):
        # map_location lets a GPU-saved checkpoint be restored on `device`.
        checkpoint = torch.load(model_dir, map_location=device)
        netG.load_state_dict(checkpoint['netG_state'])
        state_epoch = checkpoint['epoch']
        print("loading last checkpoint at epoch: ", state_epoch)
    # Previously eval() only happened after loading a checkpoint, so the
    # validation path generated images with the network in training mode.
    netG.eval()

    batch_size = cfg.TRAIN.BATCH_SIZE
    save_dir = '../images/%s/test' % (cfg.CONFIG_NAME)
    mkdir_p(save_dir)
    for i in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
        for data in dataloader:
            imags, captions, cap_lens, class_ids, keys = prepare_data(data)
            with torch.no_grad():
                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                # Noise is drawn on the CPU generator and then moved, so
                # results for a given seed match earlier runs.
                noise = torch.randn(batch_size, 100).to(device)
                fake_imgs = netG(noise, sent_emb)
            # NOTE(review): assumes every batch has exactly batch_size items
            # (the loaders in __main__/validate use drop_last=True).
            for j in range(batch_size):
                s_tmp = '%s/%s' % (save_dir, keys[j])
                folder = s_tmp[:s_tmp.rfind('/')]
                if not os.path.isdir(folder):
                    print('Make a new folder: ', folder)
                    mkdir_p(folder)
                im = fake_imgs[j].detach().cpu().numpy()
                # [-1, 1] --> [0, 255]; clip so out-of-range values cannot wrap
                # around when cast to uint8.
                im = np.clip((im + 1.0) * 127.5, 0, 255).astype(np.uint8)
                im = Image.fromarray(np.transpose(im, (1, 2, 0)))  # CHW -> HWC
                # NOTE(review): '%3d' space-pads ('_  0.png'); kept unchanged so
                # file names match those produced by earlier runs.
                fullpath = '%s_%3d.png' % (s_tmp, i)
                im.save(fullpath)
    return state_epoch
def validate(text_encoder, netG, device, writer, epoch, fid_stats_path=None):
    """Generate images for the 'test' split and report FID and Inception Score.

    Args:
        text_encoder: encoder passed through to ``sampling``.
        netG: generator; it is put into eval mode by ``sampling`` and back into
            train mode here afterwards.
        device: device passed through to ``sampling``.
        writer: summary writer forwarded to ``compute_FID`` / ``compute_IS``.
        epoch: epoch number used for logging.
        fid_stats_path: path to the reference statistics file passed as the
            first entry to ``compute_FID``. Defaults to the original
            hard-coded CUB-200 validation ``.npz`` path.

    Returns:
        ``(mean, std)`` of the Inception Score as returned by ``compute_IS``.
    """
    if fid_stats_path is None:
        fid_stats_path = ('/home/icmr/Srinivas/PhD/Text to Image/Hindi/'
                          'Arabic-text-visualization-using-ADF-GAN/data/'
                          'CUB-200/CUB-200_val.npz')
    # NOTE(review): `image_transform` is the module-level global defined in the
    # __main__ section, so this only works when the file runs as a script.
    dataset = TextBertDataset(cfg.DATA_DIR, 'test',
                              base_size=cfg.TREE.BASE_SIZE,
                              transform=image_transform)
    print(dataset.n_words, dataset.embeddings_num)
    assert dataset
    # Use the config directly instead of the __main__ global `batch_size`
    # (same value: cfg.TRAIN.BATCH_SIZE).
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=cfg.TRAIN.BATCH_SIZE, drop_last=True,
        shuffle=True, num_workers=int(cfg.WORKERS))
    print(f'Starting generate validation images ... at {epoch}')
    sampling(text_encoder, netG, dataloader, device, validation=True)
    netG.train()
    print(f'Starting compute FID & IS ... at {epoch}')
    image_dir = '../images/%s/test' % (cfg.CONFIG_NAME)
    compute_FID([fid_stats_path, image_dir], writer, epoch)
    mean, std = compute_IS(image_dir, writer, epoch)
    # The Inception Score is the mean; the old `mean / std` value had no
    # meaning and divided by zero when std == 0.
    print(f"Final Inception Score: {mean:.4f} +/- {std:.4f}")
    return mean, std
| ######################################### | |
| ######################################## | |
def train(dataloader, netG, netD, text_encoder, optimizerG, optimizerD,
          state_epoch, batch_size, device, writer):
    """Adversarially train ``netG`` / ``netD``, resuming from a checkpoint.

    If ``../models/<CONFIG_NAME>/checkpoint_nets.pth`` exists, resume from it.
    The generator, the discriminator, both optimizer states and the epoch
    counter are restored.

    Each batch performs three updates:
      1. D: hinge loss on (real image, matching text), (real image,
         mismatched text) and (fake image, text) pairs.
      2. D: matching-aware gradient penalty (MA-GP) on real images and their
         sentence embeddings.
      3. G: raise the discriminator's conditional score on the fakes.

    After every epoch, three things happen:
      - a grid of the last batch of fakes is saved;
      - the full checkpoint is saved;
      - the summed epoch losses are logged to ``writer``.

    Args:
        dataloader: yields batches consumed by ``prepare_data``.
        netG: generator called as ``netG(noise, sent_emb)``.
        netD: discriminator; ``netD.COND_DNET(features, sent_emb)`` scores
            (image features, sentence) pairs.
        text_encoder: encoder returning ``(words_embs, sent_emb)``.
        optimizerG: optimizer for ``netG``.
        optimizerD: optimizer for ``netD``.
        state_epoch: epoch to continue after when no checkpoint exists.
        batch_size: samples per batch. NOTE(review): assumed equal to every
            batch's size (the __main__ loader uses drop_last=True).
        device: device real images and noise are moved to.
        writer: summary writer used for the loss scalars.

    Returns:
        The epoch at which training paused: the first epoch that is a multiple
        of 50, otherwise ``cfg.TRAIN.MAX_EPOCH``.
    """
    path = '../models/%s/checkpoint_nets.pth' % (cfg.CONFIG_NAME)
    if os.path.exists(path):
        # map_location lets a GPU-saved checkpoint be resumed on `device`.
        checkpoint = torch.load(path, map_location=device)
        netG.load_state_dict(checkpoint['netG_state'])
        netD.load_state_dict(checkpoint['netD_state'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state'])
        state_epoch = checkpoint['epoch']
        netG.train()
        netD.train()
        print("Loading last checkpoint at epoch: ", state_epoch)
    else:
        print("No checkpoint to load")

    # Create the output locations up front so the per-epoch saves cannot fail
    # on a fresh run.
    fakes_dir = '../images/%s/fakes' % (cfg.CONFIG_NAME)
    os.makedirs(fakes_dir, exist_ok=True)
    os.makedirs(os.path.dirname(path), exist_ok=True)

    for epoch in range(state_epoch + 1, cfg.TRAIN.MAX_EPOCH + 1):
        D_loss = 0.0
        G_loss = 0.0
        fake = None  # stays None if the loader yields no batches
        for step, data in enumerate(dataloader, 0):
            imags, captions, cap_lens, class_ids, keys = prepare_data(data)
            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            imgs = imags[0].to(device)

            # (1) Discriminator hinge loss.
            real_features = netD(imgs)
            output = netD.COND_DNET(real_features, sent_emb)
            errD_real = torch.nn.ReLU()(1.0 - output).mean()
            # Pair image i with caption i+1 to form mismatched (negative) pairs.
            output = netD.COND_DNET(real_features[:(batch_size - 1)],
                                    sent_emb[1:batch_size])
            errD_mismatch = torch.nn.ReLU()(1.0 + output).mean()
            # synthesize fake images
            noise = torch.randn(batch_size, 100)
            noise = noise.to(device)
            fake = netG(noise, sent_emb)
            # G does not need update with D
            fake_features = netD(fake.detach())
            errD_fake = netD.COND_DNET(fake_features, sent_emb)
            errD_fake = torch.nn.ReLU()(1.0 + errD_fake).mean()
            errD = errD_real + (errD_fake + errD_mismatch) / 2.0
            optimizerD.zero_grad()
            optimizerG.zero_grad()
            errD.backward()
            optimizerD.step()

            # (2) MA-GP: penalize the gradient norm of D's conditional score
            # with respect to both the real image and its sentence embedding.
            interpolated = (imgs.data).requires_grad_()
            sent_inter = (sent_emb.data).requires_grad_()
            features = netD(interpolated)
            out = netD.COND_DNET(features, sent_inter)
            grads = torch.autograd.grad(outputs=out,
                                        inputs=(interpolated, sent_inter),
                                        # ones_like follows `out`'s device and
                                        # dtype instead of hard-coding .cuda().
                                        grad_outputs=torch.ones_like(out),
                                        retain_graph=True,
                                        create_graph=True,
                                        only_inputs=True)
            grad0 = grads[0].view(grads[0].size(0), -1)
            grad1 = grads[1].view(grads[1].size(0), -1)
            grad = torch.cat((grad0, grad1), dim=1)
            grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
            d_loss_gp = torch.mean((grad_l2norm) ** 6)
            d_loss = 2.0 * d_loss_gp
            optimizerD.zero_grad()
            optimizerG.zero_grad()
            d_loss.backward()
            optimizerD.step()

            # (3) Generator update against the freshly updated discriminator.
            features = netD(fake)
            output = netD.COND_DNET(features, sent_emb)
            errG = - output.mean()
            optimizerG.zero_grad()
            optimizerD.zero_grad()
            errG.backward()
            optimizerG.step()

            D_loss += errD.item() + d_loss.item()
            G_loss += errG.item()
            print('[%d/%d][%d/%d] Loss_D: %.3f Loss_G %.3f total_Loss_D: %.3f total_Loss_G %.3f'
                  % (epoch, cfg.TRAIN.MAX_EPOCH, step, len(dataloader),
                     errD.item(), errG.item(), D_loss, G_loss))

        if fake is not None:
            vutils.save_image(fake.data,
                              '%s/fake_samples_epoch_%03d.png' % (fakes_dir, epoch),
                              normalize=True)
        # if epoch%10==0:
        torch.save({
            'epoch': epoch,
            'netG_state': netG.state_dict(),
            'optimizerG_state': optimizerG.state_dict(),
            'netD_state': netD.state_dict(),
            'optimizerD_state': optimizerD.state_dict()
        }, path)
        writer.add_scalar('D_Loss/train', D_loss, epoch)
        writer.add_scalar('G_Loss/train', G_loss, epoch)
        # Pause every 50 epochs so the caller can run validation.
        if epoch % 50 == 0:
            return epoch
    return cfg.TRAIN.MAX_EPOCH
if __name__ == "__main__":
    args = parse_args()
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.gpu_id == -1:
        cfg.CUDA = False
    else:
        cfg.GPU_ID = args.gpu_id
    if args.data_dir != '':
        cfg.DATA_DIR = args.data_dir
    cfg.B_VALIDATION = bool(args.evaluation)
    print('Using config:')
    pprint.pprint(cfg)

    # Seed 100 is forced when cfg.TRAIN.FLAG is false; otherwise --manualSeed
    # is used, falling back to 100.
    if not cfg.TRAIN.FLAG or args.manualSeed is None:
        args.manualSeed = 100
        # args.manualSeed = random.randint(1, 10000)
    print("seed now is : ", args.manualSeed)
    random.seed(args.manualSeed)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)
    ##########################################################################
    # set_device raises on machines without CUDA, so only call it when a GPU
    # is actually available.
    if torch.cuda.is_available():
        torch.cuda.set_device(cfg.GPU_ID)
    cudnn.benchmark = True

    # Get data loader ##################################################
    # `imsize`, `batch_size` and `image_transform` are module-level globals;
    # validate() reads `image_transform` directly.
    imsize = cfg.TREE.BASE_SIZE
    batch_size = cfg.TRAIN.BATCH_SIZE
    image_transform = transforms.Compose([
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
        transforms.RandomHorizontalFlip()])
    # Evaluation samples from the 'test' split; training uses 'train'.
    split = 'test' if cfg.B_VALIDATION else 'train'
    dataset = TextBertDataset(cfg.DATA_DIR, split,
                              base_size=cfg.TREE.BASE_SIZE,
                              transform=image_transform)
    print(dataset.n_words, dataset.embeddings_num)
    assert dataset
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, drop_last=True,
        shuffle=True, num_workers=int(cfg.WORKERS))

    # # validation data #
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    netG = NetG(cfg.TRAIN.NF, 100).to(device)
    netD = NetD(cfg.TRAIN.NF).to(device)
    text_encoder = BERT_RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    state_dict = torch.load(cfg.TEXT.DAMSM_NAME, map_location=lambda storage, loc: storage)
    # The word-embedding table is dropped (and loading is non-strict), so that
    # weight keeps the value it was constructed with.
    state_dict.pop('encoder.embeddings.word_embeddings.weight', None)
    text_encoder.load_state_dict(state_dict, strict=False)
    # Follow `device` instead of hard-coding .cuda(), which fails without a GPU.
    text_encoder.to(device)
    for p in text_encoder.parameters():
        p.requires_grad = False
    text_encoder.eval()

    state_epoch = 0
    optimizerG = torch.optim.Adam(netG.parameters(), lr=0.0001, betas=(0.0, 0.9))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=0.0004, betas=(0.0, 0.9))
    if cfg.B_VALIDATION:
        # generate images for the whole valid dataset
        state_epoch = sampling(text_encoder, netG, dataloader, device)
        print('state_epoch: %d' % (state_epoch))
    else:
        writer = SummaryWriter(f"tensorboards/{cfg.CONFIG_NAME}/ADGAN_train")
        try:
            epoch = train(dataloader, netG, netD, text_encoder, optimizerG,
                          optimizerD, state_epoch, batch_size, device, writer)
            validate(text_encoder, netG, device, writer, epoch)
        finally:
            # Flush pending events even if training or validation fails.
            writer.close()