dianecy commited on
Commit
037cf3b
·
verified ·
1 Parent(s): 213d274

Upload ./ASDA/test.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ASDA/test.py +230 -0
ASDA/test.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import random
5
+ import datetime
6
+ import matplotlib as mpl
7
+ mpl.use('Agg')
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.parallel
12
+ import torch.backends.cudnn as cudnn
13
+ import torch.distributed as dist
14
+ import torch.optim
15
+ import torch.utils.data.distributed
16
+ from torch.utils.data import DataLoader
17
+ from torchvision.transforms import Compose, ToTensor, Normalize
18
+
19
+ import torch.distributed as dist
20
+ import torch.multiprocessing as mp
21
+ from torch.nn.parallel import DistributedDataParallel as DDP
22
+ import torch.utils.data.distributed
23
+
24
+ #import apex.amp as amp
25
+ from torch.cuda.amp import autocast as autocast
26
+
27
+ from model.model import *
28
+ from engine.engine_oiou import *
29
+
30
+ from dataset.data_loader_test import *
31
+ from utils.losses import *
32
+ from utils.parsing_metrics import *
33
+ from utils.utils import *
34
+ from utils.checkpoint import load_pretrain, load_resume
35
+ from utils.logger import setup_logger
36
+
37
+ def get_args():
38
+ parser = argparse.ArgumentParser(description='Dataloader test')
39
+ parser.add_argument('--gpu', default='2', help='gpu id')
40
+ parser.add_argument('--ngpu', default=2, type=int, help='gpu num')
41
+ parser.add_argument('--workers', default=4, type=int, help='num workers for data loading')
42
+ parser.add_argument('--seed', default=0, type=int, help='random seed')
43
+
44
+ parser.add_argument('--clip_model', default='ViT-B/16', type=str, help='clip model RN50 RN101 ViT-B/32')
45
+ parser.add_argument('--nb_epoch', default=32, type=int, help='training epoch')
46
+ parser.add_argument('--lr', default=0.000025, type=float, help='batch size 16 learning rate')
47
+ parser.add_argument('--power', default=0.1, type=float, help='lr poly power')
48
+ parser.add_argument('--steps', default=[15, 28], type=list, help='in which step lr decay by power')
49
+ parser.add_argument('--batch_size', default=16, type=int, help='batch size')
50
+ parser.add_argument('--size', default=416, type=int, help='image size')
51
+ parser.add_argument('--dataset', default='refcoco', type=str,
52
+ help='refcoco/refcoco+/refcocog/grefcoco')
53
+
54
+ parser.add_argument('--num_query', default=16, type=int, help='the number of query')
55
+ parser.add_argument('--w_seg', default=0.1, type=float, help='weight of the seg loss')
56
+ parser.add_argument('--w_coord', default=5, type=float, help='weight of the reg loss')
57
+ parser.add_argument('--tunelang', dest='tunelang', default=True, action='store_true', help='if finetune language model')
58
+ parser.add_argument('--anchor_imsize', default=416, type=int,
59
+ help='scale used to calculate anchors defined in model cfg file')
60
+ parser.add_argument('--data_root', type=str, default='./ln_data',
61
+ help='path to ReferIt splits data folder')
62
+ parser.add_argument('--split_root', type=str, default='./data',
63
+ help='location of pre-parsed dataset info')
64
+ parser.add_argument('--time', default=15, type=int,
65
+ help='maximum time steps (lang length) per batch')
66
+ parser.add_argument('--log_dir', type=str, default='./logs',
67
+ help='path to ReferIt splits data folder')
68
+
69
+ parser.add_argument('--fusion_dim', default=768, type=int,
70
+ help='fusion module embedding dimensions')
71
+ parser.add_argument('--resume', default='', type=str, metavar='PATH',
72
+ help='path to latest checkpoint (default: none)')
73
+ parser.add_argument('--pretrain', default='', type=str, metavar='PATH',
74
+ help='pretrain support load state_dict that are not identical, while have no loss saved as resume')
75
+ parser.add_argument('--print_freq', '-p', default=100, type=int,
76
+ metavar='N', help='print frequency (default: 1e3)')
77
+ parser.add_argument('--savename', default='default', type=str, help='Name head for saved model')
78
+
79
+ parser.add_argument('--seg_thresh', default=0.35, type=float, help='seg score above this value means foreground')
80
+ parser.add_argument('--seg_out_stride', default=2, type=int, help='the seg out stride')
81
+ parser.add_argument('--best_iou', default=-float('Inf'), type=int, help='the best accu')
82
+
83
+
84
+ global args, anchors_full, writer, logger
85
+ args = parser.parse_args()
86
+ args.gsize = 32
87
+ args.date = datetime.datetime.now().strftime('%Y%m%d')
88
+ if args.savename=='default':
89
+ args.savename = 'model_v1_%s_batch%d_%s'%(args.dataset, args.batch_size, args.date)
90
+ os.makedirs(args.log_dir, exist_ok=True)
91
+ args.lr = args.lr * (args.batch_size * args.ngpu // 16)
92
+
93
+ print('----------------------------------------------------------------------')
94
+ print(sys.argv[0])
95
+ print(args)
96
+ print('----------------------------------------------------------------------')
97
+
98
+ return args
99
+
100
+ def main(args):
101
+ os.environ['MASTER_ADDR'] = 'localhost'
102
+ # os.environ['MASTER_PORT'] = '12367'
103
+
104
+ if(torch.cuda.is_available()):
105
+ n_gpus = torch.cuda.device_count()
106
+ print("Running DDP with {} GPUs".format(n_gpus))
107
+ mp.spawn(run, nprocs=n_gpus, args=(n_gpus, args,))
108
+ else:
109
+ print("Please use GPU for training")
110
+
111
+ def run(rank, n_gpus, args):
112
+ dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank)
113
+ torch.cuda.set_device(rank)
114
+
115
+ ## fix seed
116
+ cudnn.benchmark = False
117
+ cudnn.deterministic = True
118
+ random.seed(args.seed)
119
+ np.random.seed(args.seed+1)
120
+ torch.manual_seed(args.seed+2)
121
+ torch.cuda.manual_seed_all(args.seed+3)
122
+
123
+ ## save logs
124
+ logger = setup_logger(output=os.path.join(args.log_dir, args.savename), distributed_rank=rank, color=False, name="model-v1")
125
+ logger.info(str(sys.argv))
126
+ logger.info(str(args))
127
+
128
+ input_transform = Compose([
129
+ ToTensor(),
130
+ Normalize(
131
+ mean=[0.48145466, 0.4578275, 0.40821073],
132
+ std=[0.26862954, 0.26130258, 0.27577711]
133
+ )
134
+ ])
135
+
136
+
137
+ val_dataset = ReferDataset(data_root=args.data_root,
138
+ dataset=args.dataset,
139
+ split_root=args.split_root,
140
+ split='val',
141
+ imsize = args.size,
142
+ transform=input_transform,
143
+ max_query_len=args.time)
144
+
145
+
146
+ val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False,
147
+ pin_memory=True, drop_last=True, num_workers=args.workers)
148
+
149
+ if args.dataset == 'refcocog_u':
150
+ test_dataset = ReferDataset(data_root=args.data_root,
151
+ dataset=args.dataset,
152
+ split_root=args.split_root,
153
+ split='test',
154
+ imsize = args.size,
155
+ transform=input_transform,
156
+ max_query_len=args.time)
157
+
158
+ test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False,
159
+ pin_memory=True, drop_last=True, num_workers=args.workers)
160
+ elif args.dataset == 'refcocog_g':
161
+ pass
162
+ else:
163
+ testA_dataset = ReferDataset(data_root=args.data_root,
164
+ dataset=args.dataset,
165
+ split_root=args.split_root,
166
+ split='testA',
167
+ imsize = args.size,
168
+ transform=input_transform,
169
+ max_query_len=args.time)
170
+ testB_dataset = ReferDataset(data_root=args.data_root,
171
+ dataset=args.dataset,
172
+ split_root=args.split_root,
173
+ split='testB',
174
+ imsize = args.size,
175
+ transform=input_transform,
176
+ max_query_len=args.time)
177
+
178
+
179
+ testA_loader = DataLoader(testA_dataset, batch_size=1, shuffle=False,
180
+ pin_memory=True, drop_last=True, num_workers=args.workers)
181
+ testB_loader = DataLoader(testB_dataset, batch_size=1, shuffle=False,
182
+ pin_memory=True, drop_last=True, num_workers=args.workers)
183
+
184
+
185
+ ## Model
186
+ model = Model(clip_model=args.clip_model, tunelang=args.tunelang, num_query=args.num_query, fusion_dim=args.fusion_dim).cuda(rank)
187
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
188
+ model_without_ddp = model.module
189
+
190
+ args.start_epoch = 0
191
+ if args.pretrain and os.path.isfile(args.pretrain):
192
+ model=load_pretrain(model, args, logger, rank)
193
+ model.to(rank)
194
+
195
+ visu_param = [param for name, param in model_without_ddp.named_parameters() if 'visumodel' in name]
196
+ text_param = [param for name, param in model_without_ddp.named_parameters() if 'textmodel' in name]
197
+ rest_param = [param for name, param in model_without_ddp.named_parameters() if 'textmodel' not in name and 'visumodel' not in name]
198
+
199
+
200
+ ## optimizer; adam default
201
+ if args.tunelang:
202
+ optimizer = torch.optim.Adam([{'params': rest_param, 'lr': args.lr},
203
+ {'params': visu_param, 'lr': args.lr / 10.},
204
+ {'params': text_param, 'lr': args.lr / 10.}])
205
+ else:
206
+ optimizer = torch.optim.Adam([{'params': rest_param},
207
+ {'params': visu_param, 'lr': args.lr / 10.}], lr=args.lr)
208
+
209
+ best_miou_seg = -float('Inf')
210
+ if args.resume:
211
+ model = load_resume(model, optimizer, args, logger, rank)
212
+ model.to(rank)
213
+ best_miou_seg = args.best_iou
214
+ print(best_miou_seg)
215
+
216
+ if args.dataset == 'refcocog_u':
217
+ print('\nTest testing:')
218
+ miou_seg, oiou_seg, prec = validate_epoch(args, test_loader, model, logger, 'test')
219
+ # elif args.dataset == 'refcocog_g':
220
+ # pass
221
+ else:
222
+ print('\nTestA testing:')
223
+ miou_seg, oiou_seg, prec = validate_epoch(args, testA_loader, model, logger, 'testA')
224
+ print('\nTestB testing:')
225
+ miou_seg, oiou_seg, prec = validate_epoch(args, testB_loader, model, logger, 'testB')
226
+ miou_seg, oiou_seg, prec = validate_epoch(args, val_loader, model, logger, 'val')
227
+
228
+ if __name__ == "__main__":
229
+ args = get_args()
230
+ main(args)