import argparse
import numpy as np
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from models.FFLIP import FLIP
from models import utils
from eval.pretrain_eval import evaluation, itm_eval
from data import create_dataset, create_sampler, create_loader

|
def main(args):
    """Run image<->text retrieval evaluation of a FLIP model on FaceCaption.

    Initializes (optionally distributed) execution, seeds all RNGs for
    reproducibility, builds the train/test data loaders, loads the
    pretrained model, computes retrieval similarity scores on the test
    split, and prints the retrieval metrics on the main process.

    Args:
        args: parsed CLI namespace (see the parser in ``__main__``).
            Must provide at least ``device``, ``seed``, ``pretrained``
            and ``distributed``; ``gpu`` is set by
            ``utils.init_distributed_mode`` in distributed mode.
    """
    utils.init_distributed_mode(args)
    device = torch.device(args.device)

    # Offset the seed by rank so each distributed worker draws a
    # different-but-deterministic random stream.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    print("Creating dataset")
    train_dataset, test_dataset = create_dataset(args, 'facecaption')

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        # Only the train set is sharded across ranks; the test set keeps a
        # None sampler so evaluation sees the full split on every rank.
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None]
    else:
        samplers = [None, None]

    train_loader, test_loader = create_loader(
        [train_dataset, test_dataset], samplers,
        batch_size=[80, 80],  # was `[80] + [80]` — same value, clearer literal
        num_workers=[8, 8],
        is_trains=[True, False],
        collate_fns=[None, None])

    print("Creating model")
    model = FLIP(pretrained=args.pretrained, vit='base', queue_size=61440)
    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        # Evaluation runs on the raw module, not the DDP wrapper.
        model_without_ddp = model.module

    print("Start evaluation")
    score_test_i2t, score_test_t2i = evaluation(args, model_without_ddp, test_loader, device)

    # Retrieval metrics are computed and printed once, on rank 0 only.
    if utils.is_main_process():
        test_result = itm_eval(score_test_i2t, score_test_t2i,
                               test_loader.dataset.txt2img,
                               test_loader.dataset.img2txt)
        print(test_result)

    if args.distributed:
        dist.barrier()  # keep all ranks in sync before exiting

|
def build_arg_parser():
    """Build the CLI parser for the FaceCaption retrieval-evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', default='./outputs')
    parser.add_argument('--img_root', default='./FaceCaption/images')
    # Fixed default: was '.FaceCaption/caption' (missing '/'), inconsistent
    # with the other ./FaceCaption/... path defaults.
    parser.add_argument('--ann_root', default='./FaceCaption/caption')
    parser.add_argument('--pretrained', default='./FaceCaption-15M-base.pth')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    # Fixed: `type=bool` is an argparse pitfall — any non-empty string
    # (including "False") parses as True. A store_true flag keeps the
    # same False default and enables distributed mode with --distributed.
    parser.add_argument('--distributed', action='store_true',
                        help='whether to use distributed mode for training')
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    main(args)