dianecy committed
Commit e290a7d · verified · 1 parent: 2133c10

Upload folder using huggingface_hub

ASDA/engine/__pycache__/engine.cpython-39.pyc ADDED (binary file, 4.17 kB)
ASDA/engine/__pycache__/engine_gref_sbert.cpython-39.pyc ADDED (binary file, 7.52 kB)
ASDA/engine/__pycache__/engine_gref_sbert_oiou.cpython-39.pyc ADDED (binary file, 7.52 kB)
ASDA/engine/__pycache__/engine_oiou.cpython-39.pyc ADDED (binary file, 4.37 kB)
ASDA/engine/__pycache__/engine_rcc_sbert.cpython-39.pyc ADDED (binary file, 6.33 kB)
ASDA/engine/engine.py ADDED
@@ -0,0 +1,167 @@
import time
import matplotlib as mpl
mpl.use('Agg')

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
from torch.autograd import Variable
from torch.cuda.amp import autocast as autocast

from model.model import *
from dataset.data_loader import *
from utils.losses import *
from utils.parsing_metrics import *
from utils.utils import *
from utils.utils import dice_loss, sigmoid_focal_loss

use_cuda = torch.cuda.is_available()
print("use_cuda, ", use_cuda)


def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    print('train at epoch %d' % epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    cos_losses = AverageMeter()
    model.train()
    end = time.time()

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map) in enumerate(train_loader):
        imgs = imgs.cuda(rank, non_blocking=True)
        word_id = word_id.cuda(rank, non_blocking=True)
        word_mask = word_mask.cuda(rank, non_blocking=True)
        seg_map = seg_map.cuda(rank, non_blocking=True)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        with autocast():
            mask_out = model(image, word_id, word_mask)
            loss = 0.

            mask_out_np = mask_out.data.cpu().numpy()  # [bs, 1, 208, 208]
            seg_map_np = seg_map.cpu().numpy()  # [bs, 1, 208, 208]
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map)

            loss += dice_loss_ + sigmoid_focal_loss_

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), imgs.size(0))
        dice_losses.update(dice_loss_.item(), imgs.size(0))
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), imgs.size(0))
        cos_losses.update(seg_iou.mean().item(), imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                        'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                        'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                        'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                        .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg


def validate_epoch(args, val_loader, model, logger, mode='val'):
    print('begin test')
    batch_time = AverageMeter()
    miou = AverageMeter()
    miou_seg = AverageMeter()

    prec = dict()
    thresholds = np.arange(0.5, 1, 0.05)

    for thresh in thresholds:
        prec[thresh] = AverageMeter()

    model.eval()
    end = time.time()
    idx = 0

    t_all = []

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):

        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        t1 = time.time()
        with torch.no_grad():
            mask_out = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        t2 = time.time()
        t_all.append(t2 - t1)

        ## test: convert pred, gt box to original scale with meta-info
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)

        ## revert image for visualization
        seg_map_np = seg_map[0, :, :, :].data.cpu().numpy().transpose(1, 2, 0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = imgs[0, :, top:bottom, left:right].data.cpu().numpy().transpose(1, 2, 0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)

        img_np = Variable(torch.from_numpy(img_np.transpose(2, 0, 1)).cuda().unsqueeze(0))

        # seg
        mask_out = mask_out[0].data.cpu().numpy().transpose(1, 2, 0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)
        seg_iou, seg_prec = cal_seg_iou(seg_map[0].cpu().numpy(), mask_out_np, args.seg_thresh)
        miou_seg.update(seg_iou, imgs.size(0))
        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                        .format(batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1

    print(miou_seg.avg)
    for thresh in thresholds:
        print("prec@%f: %f" % (thresh, float(prec[thresh].avg)))
        logger.info("prec@%f:%f" % (thresh, float(prec[thresh].avg)))
    logger.info("%f,%f" % (float(miou.avg), miou_seg.avg))
    return miou_seg.avg, prec
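Note: train_epoch above follows the standard PyTorch mixed-precision pattern (forward under autocast, backward and step through a GradScaler). Below is a minimal, self-contained sketch of just that pattern; the toy convolutional model, loss, and random data are illustrative placeholders, not part of this repository.

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

def train_step(model, batch, target, optimizer, scaler, device="cuda"):
    # One AMP training step: autocast forward, scaled backward, scaler-managed step.
    model.train()
    batch, target = batch.to(device), target.to(device)
    with autocast():  # forward pass in mixed precision
        pred = model(batch)
        loss = nn.functional.binary_cross_entropy_with_logits(pred, target)
    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
    scaler.update()
    return loss.item()

if __name__ == "__main__" and torch.cuda.is_available():
    model = nn.Conv2d(3, 1, 3, padding=1).cuda()     # stand-in for the segmentation model
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = GradScaler()
    x = torch.randn(2, 3, 208, 208)                  # dummy image batch
    y = torch.rand(2, 1, 208, 208)                   # dummy soft mask targets
    print(train_step(model, x, y, opt, scaler))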
ASDA/engine/engine_gref_sbert.py ADDED
@@ -0,0 +1,347 @@
import time
import matplotlib as mpl
mpl.use('Agg')

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
from torch.autograd import Variable
from torch.cuda.amp import autocast as autocast

from model.model_sbert_gref import *
from dataset.data_loader import *
from utils.losses import *
from utils.parsing_metrics import *
from utils.utils import *
from utils.utils import dice_loss, sigmoid_focal_loss

use_cuda = torch.cuda.is_available()
print("use_cuda, ", use_cuda)


def return_mask(emb_distance, verb_mask=None, rows_to_filter=None, cols_to_filter=None):
    B_, B_ = emb_distance.shape
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)  # Set diagonal elements to 1 for all cases

    if B_ < len(verb_mask):
        # If B_ equals 2*K (double the number of verb phrases)
        for i in range(B_ // 2):
            positive_mask[2 * i, 2 * i + 1] = 1
            positive_mask[2 * i + 1, 2 * i] = 1
    else:
        # Process the case where we have a mix of sentences with and without verbs
        i = 0
        while i < B_:
            if verb_mask[i] == 1:
                positive_mask[i, i + 1] = 1
                positive_mask[i + 1, i] = 1
                i += 2
            else:
                i += 1
    negative_mask = torch.ones_like(emb_distance) - positive_mask
    negative_mask = negative_mask.clone()

    if rows_to_filter is not None and cols_to_filter is not None:
        for row, col in zip(rows_to_filter, cols_to_filter):
            negative_mask[row * 2, col * 2] = 0
            negative_mask[row * 2, col * 2 + 1] = 0
            negative_mask[row * 2 + 1, col * 2] = 0
            negative_mask[row * 2 + 1, col * 2 + 1] = 0

    return positive_mask, negative_mask


def UniAngularLogitContrastLoss(total_fq, verb_mask, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    _, C, H, W = total_fq.shape

    # Calculate embeddings
    if verbonly:
        B = total_fq[verb_mask].shape[0]
        emb = torch.mean(total_fq[verb_mask], dim=(-1, -2)).reshape(B, C)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        emb = torch.mean(total_fq, dim=-1)

    B_ = emb.shape[0]
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)

    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (B_, B_)
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    margin_in_radians = m / 57.2958  # Convert degrees to radians
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)
    # print("sim_matrix : ", sim_matrix)
    # print("theta_matrix : ", theta_matrix)

    positive_mask, negative_mask = return_mask(sim_matrix, verb_mask, rows_to_filter, cols_to_filter)
    # print("positive_mask : ", positive_mask)
    # print("negative_mask : ", negative_mask)
    # print("positive_mask requires_grad:", positive_mask.requires_grad,
    #       "device:", positive_mask.device, "dtype:", positive_mask.dtype)
    # print("negative_mask requires_grad:", negative_mask.requires_grad,
    #       "device:", negative_mask.device, "dtype:", negative_mask.dtype)

    theta_with_margin = theta_matrix.clone()
    theta_with_margin[positive_mask.bool()] -= margin_in_radians
    logits = theta_with_margin / tau  # Scale with temperature

    # Compute exp logits for softmax
    exp_logits = torch.exp(logits)
    pos_exp_logits = exp_logits * positive_mask
    pos_exp_logits = pos_exp_logits.sum(dim=-1)
    neg_exp_logits = exp_logits * negative_mask
    neg_exp_logits = neg_exp_logits.sum(dim=-1)

    total_exp_logits = pos_exp_logits + neg_exp_logits

    positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
    angular_loss = positive_loss.mean()
    # print("angular_loss : ", angular_loss)

    return angular_loss, B_


def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    print('train at epoch %d' % epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    cos_losses = AverageMeter()
    model.train()
    end = time.time()

    # arguments for the verb-centric radial contrastive loss
    mlw = args.metric_loss_weight
    metric_mode = args.metric_mode
    filter_thres = args.filter_thres
    metric_learning = args.metric_learning

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, params) in enumerate(train_loader):
        B = imgs.size(0)  # original batch size

        hp_word_id = params['hp_word_id']
        hp_word_mask = params['hp_word_mask']
        hp_bert_embs = params['hardpos_emb'].cuda(non_blocking=True).squeeze(1)
        pos_type = np.array(params['pos_type'])

        pos_mask = torch.tensor(np.where(pos_type == 'hardpos', 1, 0))

        # print(hp_bert_embs.shape)
        # print(imgs.shape, word_id.shape, word_mask.shape, seg_map.shape)

        # hardpos flag outside the model
        verb_masks = []
        cl_masks = []
        images = []
        targets = []
        sentences_ = []
        sentences_masked_ = []

        for idx in range(len(imgs)):
            sentences_.append(word_id[idx])
            sentences_masked_.append(word_mask[idx])
            images.append(imgs[idx])
            targets.append(seg_map[idx])

            # If a verb exists, process it
            if pos_mask[idx]:
                verb_masks.extend([1, 1])  # Both the original sentence and the verb phrase are marked
                cl_masks.extend([1, 0])    # Only the original sentence gets marked
                sentences_.append(hp_word_id[idx])
                sentences_masked_.append(hp_word_mask[idx])
                images.append(imgs[idx])
                targets.append(seg_map[idx])
            else:
                verb_masks.append(0)
                cl_masks.append(1)

        imgs, seg_map, word_id, word_mask, verb_masks, cl_masks = \
            torch.stack(images).cuda(rank, non_blocking=True),\
            torch.stack(targets).cuda(rank, non_blocking=True),\
            torch.stack(sentences_).cuda(rank, non_blocking=True),\
            torch.stack(sentences_masked_).cuda(rank, non_blocking=True),\
            torch.tensor(verb_masks, dtype=torch.bool).cuda(rank, non_blocking=True),\
            torch.tensor(cl_masks, dtype=torch.bool).cuda(rank, non_blocking=True)

        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)
        verb_masks = Variable(verb_masks)
        cl_masks = Variable(cl_masks)

        if hp_bert_embs.numel() > 0:
            mask = ~torch.all(hp_bert_embs == 0, dim=1)
            hp_bert_embs = hp_bert_embs[mask]
            # print(hp_bert_embs.shape, hp_bert_embs.requires_grad, hp_bert_embs.device)
            norms = torch.norm(hp_bert_embs, dim=-1, keepdim=True)
            normed_embs = hp_bert_embs / norms
            cosime_sim = torch.mm(normed_embs, normed_embs.T)
            rows_to_filter, cols_to_filter = torch.where(cosime_sim > filter_thres)

            # print(normed_embs, normed_embs.requires_grad, normed_embs.device)
            # print(cosime_sim, cosime_sim.requires_grad, cosime_sim.device)
            # print("rows_to_filter : ", rows_to_filter, rows_to_filter.requires_grad)
            # print("cols_to_filter : ", cols_to_filter, cols_to_filter.requires_grad)

        with autocast():
            mask_out_all, metric_tensors = model(image, word_id, word_mask)
            loss = 0.

            # get mask and seg_map for calculating the existing loss functions (IoU loss, dice loss, sigmoid focal loss)
            mask_out = mask_out_all[cl_masks]
            seg_map_cl = seg_map[cl_masks]

            mask_out_np = mask_out.data.cpu().numpy()  # [bs, 1, 208, 208]
            seg_map_np = seg_map_cl.cpu().numpy()  # [bs, 1, 208, 208]
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map_cl)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map_cl)

            dice_weight, focal_weight = 1.0, 1.0
            loss = (dice_weight * dice_loss_) + (focal_weight * sigmoid_focal_loss_)

            # get the angular contrastive loss over original & verb phrase pairs (only where a hardpos verb phrase exists)
            if metric_learning and sum(pos_mask) > 1:
                metric_weight = mlw
                # NS is the number of orig-verb pairs where a verb phrase exists.
                metric_loss, NS = UniAngularLogitContrastLoss(metric_tensors, verb_masks, rows_to_filter, cols_to_filter, m=args.margin_value, tau=args.temperature, verbonly=True, args=args)
                loss += metric_weight * metric_loss

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), B)
        dice_losses.update(dice_loss_.item(), B)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), B)
        cos_losses.update(seg_iou.mean().item(), B)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                        'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                        'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                        'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                        .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg


def validate_epoch(args, val_loader, model, logger, mode='val'):
    print('begin test')
    batch_time = AverageMeter()
    miou = AverageMeter()
    miou_seg = AverageMeter()

    prec = dict()
    thresholds = np.arange(0.5, 1, 0.05)

    for thresh in thresholds:
        prec[thresh] = AverageMeter()

    model.eval()
    end = time.time()
    idx = 0

    t_all = []
    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):

        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        t1 = time.time()
        with torch.no_grad():
            mask_out, _ = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        t2 = time.time()
        t_all.append(t2 - t1)

        ## test: convert pred, gt box to original scale with meta-info
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)

        ## revert image for visualization
        seg_map_np = seg_map[0, :, :, :].data.cpu().numpy().transpose(1, 2, 0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = imgs[0, :, top:bottom, left:right].data.cpu().numpy().transpose(1, 2, 0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)

        img_np = Variable(torch.from_numpy(img_np.transpose(2, 0, 1)).cuda().unsqueeze(0))

        # seg
        mask_out = mask_out[0].data.cpu().numpy().transpose(1, 2, 0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)
        # seg_iou, seg_prec = cal_seg_iou(seg_map[0].cpu().numpy(), mask_out_np, args.seg_thresh)
        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum

        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                        .format(batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1

    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f" % (thresh, float(prec[thresh].avg)))
        logger.info("prec@%f:%f" % (thresh, float(prec[thresh].avg)))
    # logger.info("%f,%f" % (float(miou.avg), miou_seg.avg))

    return miou_seg.avg, prec
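Note: UniAngularLogitContrastLoss above maps cosine similarities of pooled cross-modal features to angle-like scores, subtracts an angular margin on positive (original sentence, hard-positive verb phrase) pairs, and applies a temperature-scaled softmax. A minimal standalone sketch of that computation follows; the random embeddings and the "rows 2k and 2k+1 form a pair" layout are illustrative assumptions, not the model's actual features.

import torch

def angular_contrast(emb, m_deg=0.5, tau=0.05):
    # emb: (B, C) with B even; rows (2k, 2k+1) are treated as a positive pair.
    B, _ = emb.shape
    sim = torch.nn.functional.cosine_similarity(emb.unsqueeze(1), emb.unsqueeze(0), dim=-1)
    sim = torch.clamp(sim, -0.9999, 0.9999)
    theta = (torch.pi / 2) - torch.acos(sim)        # similarity mapped to an angle-like score
    pos = torch.zeros_like(sim)
    pos.fill_diagonal_(1)
    idx = torch.arange(0, B, 2)
    pos[idx, idx + 1] = 1                           # pair each sentence with its verb phrase
    pos[idx + 1, idx] = 1
    neg = 1.0 - pos
    theta = theta.clone()
    theta[pos.bool()] -= m_deg / 57.2958            # angular margin on positives (degrees -> radians)
    logits = torch.exp(theta / tau)                 # temperature-scaled exp logits
    pos_sum = (logits * pos).sum(-1)
    denom = pos_sum + (logits * neg).sum(-1)
    return (-torch.log(pos_sum / denom)).mean()

print(angular_contrast(torch.randn(8, 512)))        # e.g. 4 (sentence, verb-phrase) pairs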
ASDA/engine/engine_gref_sbert_oiou.py ADDED
@@ -0,0 +1,340 @@
import time
import matplotlib as mpl
mpl.use('Agg')

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
from torch.autograd import Variable
from torch.cuda.amp import autocast as autocast

from model.model_sbert_gref import *
from dataset.data_loader import *
from utils.losses import *
from utils.parsing_metrics import *
from utils.utils import *
from utils.utils import dice_loss, sigmoid_focal_loss

use_cuda = torch.cuda.is_available()
print("use_cuda, ", use_cuda)


def return_mask(emb_distance, verb_mask=None, rows_to_filter=None, cols_to_filter=None):
    B_, B_ = emb_distance.shape
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)  # Set diagonal elements to 1 for all cases

    if B_ < len(verb_mask):
        # If B_ equals 2*K (double the number of verb phrases)
        for i in range(B_ // 2):
            positive_mask[2 * i, 2 * i + 1] = 1
            positive_mask[2 * i + 1, 2 * i] = 1
    else:
        # Process the case where we have a mix of sentences with and without verbs
        i = 0
        while i < B_:
            if verb_mask[i] == 1:
                positive_mask[i, i + 1] = 1
                positive_mask[i + 1, i] = 1
                i += 2
            else:
                i += 1
    negative_mask = torch.ones_like(emb_distance) - positive_mask
    negative_mask = negative_mask.clone()

    if rows_to_filter is not None and cols_to_filter is not None:
        for row, col in zip(rows_to_filter, cols_to_filter):
            negative_mask[row * 2, col * 2] = 0
            negative_mask[row * 2, col * 2 + 1] = 0
            negative_mask[row * 2 + 1, col * 2] = 0
            negative_mask[row * 2 + 1, col * 2 + 1] = 0

    return positive_mask, negative_mask


def UniAngularLogitContrastLoss(total_fq, verb_mask, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    _, C, H, W = total_fq.shape

    # Calculate embeddings
    if verbonly:
        B = total_fq[verb_mask].shape[0]
        emb = torch.mean(total_fq[verb_mask], dim=(-1, -2)).reshape(B, C)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        emb = torch.mean(total_fq, dim=-1)

    B_ = emb.shape[0]
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)

    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (B_, B_)
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    margin_in_radians = m / 57.2958  # Convert degrees to radians
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)
    # print("sim_matrix : ", sim_matrix)
    # print("theta_matrix : ", theta_matrix)

    positive_mask, negative_mask = return_mask(sim_matrix, verb_mask, rows_to_filter, cols_to_filter)
    # print("positive_mask : ", positive_mask)
    # print("negative_mask : ", negative_mask)
    # print("positive_mask requires_grad:", positive_mask.requires_grad,
    #       "device:", positive_mask.device, "dtype:", positive_mask.dtype)
    # print("negative_mask requires_grad:", negative_mask.requires_grad,
    #       "device:", negative_mask.device, "dtype:", negative_mask.dtype)

    theta_with_margin = theta_matrix.clone()
    theta_with_margin[positive_mask.bool()] -= margin_in_radians
    logits = theta_with_margin / tau  # Scale with temperature

    # Compute exp logits for softmax
    exp_logits = torch.exp(logits)
    pos_exp_logits = exp_logits * positive_mask
    pos_exp_logits = pos_exp_logits.sum(dim=-1)
    neg_exp_logits = exp_logits * negative_mask
    neg_exp_logits = neg_exp_logits.sum(dim=-1)

    total_exp_logits = pos_exp_logits + neg_exp_logits

    positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
    angular_loss = positive_loss.mean()
    # print("angular_loss : ", angular_loss)

    return angular_loss, B_


def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    print('train at epoch %d' % epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    cos_losses = AverageMeter()
    model.train()
    end = time.time()

    # arguments for the verb-centric radial contrastive loss
    mlw = args.metric_loss_weight
    metric_mode = args.metric_mode
    filter_thres = args.filter_thres
    metric_learning = args.metric_learning

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, params) in enumerate(train_loader):
        B = imgs.size(0)  # original batch size

        hp_word_id = params['hp_word_id']
        hp_word_mask = params['hp_word_mask']
        hp_bert_embs = params['hardpos_emb'].cuda(non_blocking=True).squeeze(1)
        pos_type = np.array(params['pos_type'])

        pos_mask = torch.tensor(np.where(pos_type == 'hardpos', 1, 0))

        # print(hp_bert_embs.shape)
        # print(imgs.shape, word_id.shape, word_mask.shape, seg_map.shape)

        # hardpos flag outside the model
        verb_masks = []
        cl_masks = []
        images = []
        targets = []
        sentences_ = []
        sentences_masked_ = []

        for idx in range(len(imgs)):
            sentences_.append(word_id[idx])
            sentences_masked_.append(word_mask[idx])
            images.append(imgs[idx])
            targets.append(seg_map[idx])

            # If a verb exists, process it
            if pos_mask[idx]:
                verb_masks.extend([1, 1])  # Both the original sentence and the verb phrase are marked
                cl_masks.extend([1, 0])    # Only the original sentence gets marked
                sentences_.append(hp_word_id[idx])
                sentences_masked_.append(hp_word_mask[idx])
                images.append(imgs[idx])
                targets.append(seg_map[idx])
            else:
                verb_masks.append(0)
                cl_masks.append(1)

        imgs, seg_map, word_id, word_mask, verb_masks, cl_masks = \
            torch.stack(images).cuda(rank, non_blocking=True),\
            torch.stack(targets).cuda(rank, non_blocking=True),\
            torch.stack(sentences_).cuda(rank, non_blocking=True),\
            torch.stack(sentences_masked_).cuda(rank, non_blocking=True),\
            torch.tensor(verb_masks, dtype=torch.bool).cuda(rank, non_blocking=True),\
            torch.tensor(cl_masks, dtype=torch.bool).cuda(rank, non_blocking=True)

        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)
        verb_masks = Variable(verb_masks)
        cl_masks = Variable(cl_masks)

        if hp_bert_embs.numel() > 0:
            mask = ~torch.all(hp_bert_embs == 0, dim=1)
            hp_bert_embs = hp_bert_embs[mask]
            # print(hp_bert_embs.shape, hp_bert_embs.requires_grad, hp_bert_embs.device)
            norms = torch.norm(hp_bert_embs, dim=-1, keepdim=True)
            normed_embs = hp_bert_embs / norms
            cosime_sim = torch.mm(normed_embs, normed_embs.T)
            rows_to_filter, cols_to_filter = torch.where(cosime_sim > filter_thres)

            # print(normed_embs, normed_embs.requires_grad, normed_embs.device)
            # print(cosime_sim, cosime_sim.requires_grad, cosime_sim.device)
            # print("rows_to_filter : ", rows_to_filter, rows_to_filter.requires_grad)
            # print("cols_to_filter : ", cols_to_filter, cols_to_filter.requires_grad)

        with autocast():
            mask_out_all, metric_tensors = model(image, word_id, word_mask)
            loss = 0.

            # get mask and seg_map for calculating the existing loss functions (IoU loss, dice loss, sigmoid focal loss)
            mask_out = mask_out_all[cl_masks]
            seg_map_cl = seg_map[cl_masks]

            mask_out_np = mask_out.data.cpu().numpy()  # [bs, 1, 208, 208]
            seg_map_np = seg_map_cl.cpu().numpy()  # [bs, 1, 208, 208]
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map_cl)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map_cl)

            dice_weight, focal_weight = 1.0, 1.0
            loss = (dice_weight * dice_loss_) + (focal_weight * sigmoid_focal_loss_)

            # get the angular contrastive loss over original & verb phrase pairs (only where a hardpos verb phrase exists)
            if metric_learning and sum(pos_mask) > 1:
                metric_weight = mlw
                # NS is the number of orig-verb pairs where a verb phrase exists.
                metric_loss, NS = UniAngularLogitContrastLoss(metric_tensors, verb_masks, rows_to_filter, cols_to_filter, m=args.margin_value, tau=args.temperature, verbonly=True, args=args)
                loss += metric_weight * metric_loss

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), B)
        dice_losses.update(dice_loss_.item(), B)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), B)
        cos_losses.update(seg_iou.mean().item(), B)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                        'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                        'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                        'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                        .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg


def validate_epoch(args, val_loader, model, logger, mode='val'):
    print('begin test')
    batch_time = AverageMeter()
    miou = AverageMeter()
    miou_seg = AverageMeter()

    prec = dict()
    thresholds = np.arange(0.5, 1, 0.05)

    for thresh in thresholds:
        prec[thresh] = AverageMeter()

    model.eval()
    end = time.time()
    idx = 0

    t_all = []
    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):

        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        t1 = time.time()
        with torch.no_grad():
            mask_out, _ = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        t2 = time.time()
        t_all.append(t2 - t1)

        ## test: convert pred, gt box to original scale with meta-info
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)

        ## revert image for visualization
        seg_map_np = seg_map[0, :, :, :].data.cpu().numpy().transpose(1, 2, 0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = imgs[0, :, top:bottom, left:right].data.cpu().numpy().transpose(1, 2, 0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)

        img_np = Variable(torch.from_numpy(img_np.transpose(2, 0, 1)).cuda().unsqueeze(0))

        # seg
        mask_out = mask_out[0].data.cpu().numpy().transpose(1, 2, 0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)
        # seg_iou, seg_prec = cal_seg_iou(seg_map[0].cpu().numpy(), mask_out_np, args.seg_thresh)
        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum

        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                        .format(batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1

    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f" % (thresh, float(prec[thresh].avg)))
        logger.info("prec@%f:%f" % (thresh, float(prec[thresh].avg)))
    # logger.info("%f,%f" % (float(miou.avg), miou_seg.avg))
    return miou_seg.avg, overall_iou, prec
ASDA/engine/engine_oiou.py ADDED
@@ -0,0 +1,179 @@
import time
import matplotlib as mpl
mpl.use('Agg')

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
from torch.autograd import Variable
from torch.cuda.amp import autocast as autocast

from model.model import *
from dataset.data_loader import *
from utils.losses import *
from utils.parsing_metrics import *
from utils.utils import *
from utils.utils import dice_loss, sigmoid_focal_loss

use_cuda = torch.cuda.is_available()
print("use_cuda, ", use_cuda)


def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    print('train at epoch %d' % epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    cos_losses = AverageMeter()
    model.train()
    end = time.time()

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map) in enumerate(train_loader):
        imgs = imgs.cuda(rank, non_blocking=True)
        word_id = word_id.cuda(rank, non_blocking=True)
        word_mask = word_mask.cuda(rank, non_blocking=True)
        seg_map = seg_map.cuda(rank, non_blocking=True)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        with autocast():
            mask_out = model(image, word_id, word_mask)
            loss = 0.

            mask_out_np = mask_out.data.cpu().numpy()  # [bs, 1, 208, 208]
            seg_map_np = seg_map.cpu().numpy()  # [bs, 1, 208, 208]
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map)

            loss += dice_loss_ + sigmoid_focal_loss_

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), imgs.size(0))
        dice_losses.update(dice_loss_.item(), imgs.size(0))
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), imgs.size(0))
        cos_losses.update(seg_iou.mean().item(), imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                        'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                        'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                        'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                        .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg


def validate_epoch(args, val_loader, model, logger, mode='val'):
    print('begin test')
    batch_time = AverageMeter()
    miou = AverageMeter()
    miou_seg = AverageMeter()

    prec = dict()
    thresholds = np.arange(0.5, 1, 0.05)

    for thresh in thresholds:
        prec[thresh] = AverageMeter()

    model.eval()
    end = time.time()
    idx = 0

    t_all = []
    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):

        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        t1 = time.time()
        with torch.no_grad():
            mask_out = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        t2 = time.time()
        t_all.append(t2 - t1)

        ## test: convert pred, gt box to original scale with meta-info
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)

        ## revert image for visualization
        seg_map_np = seg_map[0, :, :, :].data.cpu().numpy().transpose(1, 2, 0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = imgs[0, :, top:bottom, left:right].data.cpu().numpy().transpose(1, 2, 0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)

        img_np = Variable(torch.from_numpy(img_np.transpose(2, 0, 1)).cuda().unsqueeze(0))

        # seg
        mask_out = mask_out[0].data.cpu().numpy().transpose(1, 2, 0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)
        # seg_iou, seg_prec = cal_seg_iou(seg_map[0].cpu().numpy(), mask_out_np, args.seg_thresh)
        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum

        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                        .format(batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1

    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f" % (thresh, float(prec[thresh].avg)))
        logger.info("prec@%f:%f" % (thresh, float(prec[thresh].avg)))
    # logger.info("%f,%f" % (float(miou.avg), miou_seg.avg))
    return miou_seg.avg, overall_iou, prec
ASDA/engine/engine_rcc_sbert.py ADDED
@@ -0,0 +1,258 @@
import time
import matplotlib as mpl
mpl.use('Agg')

import numpy as np
import torch
import torch.nn.parallel
import torch.optim
from torch.autograd import Variable
from torch.cuda.amp import autocast as autocast

from model.model_sbert_gref import *
from dataset.data_loader import *
from utils.losses import *
from utils.parsing_metrics import *
from utils.utils import *
from utils.utils import dice_loss, sigmoid_focal_loss

use_cuda = torch.cuda.is_available()
print("use_cuda, ", use_cuda)


def return_mask(emb_distance, rows_to_filter=None, cols_to_filter=None):
    B_, B_ = emb_distance.shape
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)  # Set diagonal elements to 1 for all cases
    negative_mask = torch.ones_like(emb_distance) - positive_mask
    negative_mask = negative_mask.clone()

    if rows_to_filter is not None and cols_to_filter is not None:
        for row, col in zip(rows_to_filter, cols_to_filter):
            negative_mask[row, col] = 0
    return positive_mask, negative_mask


def UniAngularLogitContrastLoss(total_fq, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    _, C, H, W = total_fq.shape

    B = total_fq.shape[0]
    emb = torch.mean(total_fq, dim=(-1, -2)).reshape(B, C)

    B_ = emb.shape[0]
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)

    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (B_, B_)
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    margin_in_radians = m / 57.2958  # Convert degrees to radians
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)
    # print("sim_matrix : ", sim_matrix)
    # print("theta_matrix : ", theta_matrix)
    positive_mask, negative_mask = return_mask(sim_matrix, rows_to_filter, cols_to_filter)

    theta_with_margin = theta_matrix.clone()
    theta_with_margin[positive_mask.bool()] -= margin_in_radians
    logits = theta_with_margin / tau  # Scale with temperature

    # Compute exp logits for softmax
    exp_logits = torch.exp(logits)
    pos_exp_logits = exp_logits * positive_mask
    pos_exp_logits = pos_exp_logits.sum(dim=-1)
    neg_exp_logits = exp_logits * negative_mask
    neg_exp_logits = neg_exp_logits.sum(dim=-1)

    total_exp_logits = pos_exp_logits + neg_exp_logits

    positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
    angular_loss = positive_loss.mean()

    return angular_loss


def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    print('train at epoch %d' % epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    cos_losses = AverageMeter()
    model.train()
    end = time.time()

    # arguments for the verb-centric radial contrastive loss
    mlw = args.metric_loss_weight
    metric_mode = args.metric_mode
    filter_thres = args.filter_thres
    metric_learning = args.metric_learning

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, params) in enumerate(train_loader):
        B = imgs.size(0)  # original batch size
        hp_bert_embs = params['hardpos_emb'].cuda(non_blocking=True).squeeze(1)

        imgs = imgs.cuda(rank, non_blocking=True)
        word_id = word_id.cuda(rank, non_blocking=True)
        word_mask = word_mask.cuda(rank, non_blocking=True)
        seg_map = seg_map.cuda(rank, non_blocking=True)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        if hp_bert_embs.numel() > 0:
            # print(hp_bert_embs.shape, hp_bert_embs.requires_grad, hp_bert_embs.device)
            norms = torch.norm(hp_bert_embs, dim=-1, keepdim=True)
            normed_embs = hp_bert_embs / norms
            cosime_sim = torch.mm(normed_embs, normed_embs.T)
            rows_to_filter, cols_to_filter = torch.where(cosime_sim > filter_thres)

        with autocast():
            mask_out, metric_tensors = model(image, word_id, word_mask)
            loss = 0.

            # get mask and seg_map for calculating the existing loss functions (IoU loss, dice loss, sigmoid focal loss)

            mask_out_np = mask_out.data.cpu().numpy()  # [bs, 1, 208, 208]
            seg_map_np = seg_map.cpu().numpy()  # [bs, 1, 208, 208]
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map)

            dice_weight, focal_weight = 1.0, 1.0
            loss = (dice_weight * dice_loss_) + (focal_weight * sigmoid_focal_loss_)

            # get the angular contrastive loss over original & verb phrase pairs (only where a hardpos verb phrase exists)
            if metric_learning:
                metric_weight = mlw
                metric_loss = UniAngularLogitContrastLoss(metric_tensors, rows_to_filter, cols_to_filter, m=args.margin_value, tau=args.temperature, verbonly=True, args=args)

                loss += metric_weight * metric_loss

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), B)
        dice_losses.update(dice_loss_.item(), B)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), B)
        cos_losses.update(seg_iou.mean().item(), B)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                        'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                        'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                        'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                        .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg


def validate_epoch(args, val_loader, model, logger, mode='val'):
    print('begin test')
    batch_time = AverageMeter()
    miou = AverageMeter()
    miou_seg = AverageMeter()

    prec = dict()
    thresholds = np.arange(0.5, 1, 0.05)

    for thresh in thresholds:
        prec[thresh] = AverageMeter()

    model.eval()
    end = time.time()
    idx = 0

    t_all = []
    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):

        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        t1 = time.time()
        with torch.no_grad():
            mask_out, _ = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        t2 = time.time()
        t_all.append(t2 - t1)

        ## test: convert pred, gt box to original scale with meta-info
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)

        ## revert image for visualization
        seg_map_np = seg_map[0, :, :, :].data.cpu().numpy().transpose(1, 2, 0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = imgs[0, :, top:bottom, left:right].data.cpu().numpy().transpose(1, 2, 0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)

        img_np = Variable(torch.from_numpy(img_np.transpose(2, 0, 1)).cuda().unsqueeze(0))

        # seg
        mask_out = mask_out[0].data.cpu().numpy().transpose(1, 2, 0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)
        # seg_iou, seg_prec = cal_seg_iou(seg_map[0].cpu().numpy(), mask_out_np, args.seg_thresh)
        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum

        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                        .format(batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1

    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f" % (thresh, float(prec[thresh].avg)))
        logger.info("prec@%f:%f" % (thresh, float(prec[thresh].avg)))
    # logger.info("%f,%f" % (float(miou.avg), miou_seg.avg))
    return miou_seg.avg, overall_iou, prec
ASDA/engine/tmp.py ADDED
@@ -0,0 +1,292 @@
import os
import time
from tqdm import tqdm
import cv2
import numpy as np
import torch
import pdb
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.nn.functional as F
import wandb
from loguru import logger
from utils.dataset_verbonly import tokenize
from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather,
                        trainMetricGPU)

## todo : add oIoU metric
def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
    # torch.autograd.set_detect_anomaly(True)
    batch_time = AverageMeter('Batch', ':2.2f')
    data_time = AverageMeter('Data', ':2.2f')
    lr = AverageMeter('Lr', ':1.6f')
    loss_meter = AverageMeter('Loss', ':2.4f')
    iou_meter = AverageMeter('IoU', ':2.2f')
    pr_meter = AverageMeter('Prec@50', ':2.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
        prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))

    model.train()
    time.sleep(2)
    end = time.time()

    # size_list = [320, 352, 384, 416, 448, 480, 512]
    # idx = np.random.choice(len(size_list))
    # new_size = size_list[idx]

    for i, (image, text, target, hardpos, params) in enumerate(train_loader):
        data_time.update(time.time() - end)

        # data
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True).unsqueeze(1)
        hardpos = hardpos.cuda(non_blocking=True)
        hp_emb = params['hardpos_emb'].cuda(non_blocking=True)

        with amp.autocast():
            pred, target, loss = model(image, text, target, hardpos, hp_emb)  # , fq, vis, word, state

        # backward
        optimizer.zero_grad()
        # scaler.scale(loss).backward()
        scaler.scale(loss).backward()
        # loss.backward()

        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         if torch.isinf(param.grad).any() or torch.isnan(param.grad).any():
        #             print(f"Inf/NaN in gradients: {name}")
        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         grad_norm = param.grad.norm()
        #         if torch.isnan(grad_norm):
        #             print(f"NaN gradient detected in {name}")

        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        # optimizer.step()
        # scheduler.step()
        scaler.step(optimizer)
        scaler.update()

        # metric
        iou, pr5 = trainMetricGPU(pred, target, 0.35, 0.5)
        dist.all_reduce(loss.detach())
        dist.all_reduce(iou)
        dist.all_reduce(pr5)
        loss = loss / dist.get_world_size()
        iou = iou / dist.get_world_size()
        pr5 = pr5 / dist.get_world_size()

        loss_meter.update(loss.item(), image.size(0))
        iou_meter.update(iou.item(), image.size(0))
        pr_meter.update(pr5.item(), image.size(0))
        lr.update(scheduler.get_last_lr()[-1])
        batch_time.update(time.time() - end)
        end = time.time()

        # if (i + 1) % args.print_freq == 0:
        #     progress.display(i + 1)
        #     if dist.get_rank() in [-1, 0]:
        #         wandb.log(
        #             {
        #                 "time/batch": batch_time.val,
        #                 "time/data": data_time.val,
        #                 "training/lr": lr.val,
        #                 "training/loss": loss_meter.val,
        #                 "training/iou": iou_meter.val,
        #                 "training/prec@50": pr_meter.val,
        #             },
        #             step=epoch * len(train_loader) + (i + 1))


@torch.no_grad()
def validate(val_loader, model, epoch, args):
    iou_list = []
    I_list = []
    U_list = []
    model.eval()
    time.sleep(2)
    for imgs, texts, masks, param in val_loader:
        # data
        imgs = imgs.cuda(non_blocking=True)
        texts = texts.cuda(non_blocking=True)
        # inference
        preds = model(imgs, texts)
        preds = torch.sigmoid(preds)
        if preds.shape[-2:] != imgs.shape[-2:]:
            preds = F.interpolate(preds,
                                  size=imgs.shape[-2:],
                                  mode='bicubic',
                                  align_corners=True).squeeze(1)
        # process one batch
        # for pred, mask_dir, mat, ori_size in zip(preds, param['mask_dir'],
        #                                          param['inverse'],
        #                                          param['ori_size']):
        #     h, w = np.array(ori_size)
        #     mat = np.array(mat)
        #     pred = pred.cpu().numpy()
        #     pred = cv2.warpAffine(pred, mat, (w, h),
        #                           flags=cv2.INTER_CUBIC,
        #                           borderValue=0.)
        #     pred = np.array(pred > 0.35)
        #     mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
        #     mask = mask / 255.
        #     # iou
        #     inter = np.logical_and(pred, mask)
        #     union = np.logical_or(pred, mask)
        #     iou = np.sum(inter) / (np.sum(union) + 1e-6)
        #     iou_list.append(iou)
        #     I_list.append(inter)
        #     U_list.append(union)
        for pred, mask in zip(preds, masks):
            # h, w = np.array(ori_size)
            # mat = np.array(mat)
            pred = pred.cpu().numpy()
            # pred = cv2.warpAffine(pred, mat, (w, h),
            #                       flags=cv2.INTER_CUBIC,
            #                       borderValue=0.)
            pred = np.array(pred > 0.35)
            # mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
            # mask = mask / 255.
            mask = mask.numpy()
            # iou
            inter = np.logical_and(pred, mask)
            union = np.logical_or(pred, mask)
            iou = np.sum(inter) / (np.sum(union) + 1e-6)
            I_list.append(inter)
            U_list.append(union)
            iou_list.append(iou)

    iou_list = np.stack(iou_list)
    iou_list = torch.from_numpy(iou_list).to(imgs.device)
    iou_list = concat_all_gather(iou_list)

    I_list = np.stack(I_list)
    I_list = torch.from_numpy(I_list).to(imgs.device)
    I_list = concat_all_gather(I_list)

    U_list = np.stack(U_list)
    U_list = torch.from_numpy(U_list).to(imgs.device)
    U_list = concat_all_gather(U_list)

    overall_I = I_list.sum().item()
    overall_U = U_list.sum().item()
    overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero

    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (iou_list > thres).float().mean()
        prec_list.append(tmp)
    iou = iou_list.mean()
    prec = {}
    temp = ' '
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres * 10)
        value = prec_list[i].item()
        prec[key] = value
        temp += "{}: {:.2f} ".format(key, 100. * value)
    head = 'Evaluation: Epoch=[{}/{}] IoU={:.2f} OIoU={:.4f}'.format(
        epoch, args.epochs, 100. * iou.item(), 100. * overall_IoU)
    logger.info(head + temp)
    # print(head)

    # return three results: mIoU, oIoU and prec results
    return iou.item(), overall_IoU, prec


@torch.no_grad()
def inference(test_loader, model, args):
    iou_list = []
    I_list = []
    U_list = []

    tbar = tqdm(test_loader, desc='Inference:', ncols=100)
    model.eval()
    time.sleep(2)
    for img, mask, param in tbar:
        # data
        # img = img.cuda(non_blocking=True)
        # mask = cv2.imread(param['mask_dir'][0], flags=cv2.IMREAD_GRAYSCALE)
        img = img.cuda(non_blocking=True)
        mask = mask[0].cpu().numpy()

        # dump image & mask
        if args.visualize:
            seg_id = param['seg_id'][0].cpu().numpy()
            img_name = '{}-img.jpg'.format(seg_id)
            mask_name = '{}-mask.png'.format(seg_id)
            cv2.imwrite(filename=os.path.join(args.vis_dir, img_name),
                        img=param['ori_img'][0].cpu().numpy())
            cv2.imwrite(filename=os.path.join(args.vis_dir, mask_name),
                        img=mask)
        # multiple sentences
        for sent in param['sents']:
            # mask = mask / 255.
            text = tokenize(sent, args.word_len, True)
            text = text.cuda(non_blocking=True)
            # inference
            pred = model(img, text)
            pred = torch.sigmoid(pred)
            if pred.shape[-2:] != img.shape[-2:]:
                pred = F.interpolate(pred,
                                     size=img.shape[-2:],
                                     mode='bicubic',
                                     align_corners=True).squeeze()
            # process one sentence
            # h, w = param['ori_size'].numpy()[0]
            # mat = param['inverse'].numpy()[0]
            pred = pred.cpu().numpy()
            # pred = cv2.warpAffine(pred, mat, (w, h),
            #                       flags=cv2.INTER_CUBIC,
            #                       borderValue=0.)
            pred = np.array(pred > 0.35)
            # iou
            inter = np.logical_and(pred, mask)
            union = np.logical_or(pred, mask)
            iou = np.sum(inter) / (np.sum(union) + 1e-6)
            iou_list.append(iou)
            I_list.append(inter)
            U_list.append(union)
            # dump prediction
            if args.visualize:
                pred = np.array(pred * 255, dtype=np.uint8)
                sent = "_".join(sent[0].split(" "))
                pred_name = '{}-iou={:.2f}-{}.png'.format(seg_id, iou * 100, sent)
                cv2.imwrite(filename=os.path.join(args.vis_dir, pred_name),
                            img=pred)
    logger.info('=> Metric Calculation <=')
    iou_list = np.stack(iou_list)
    iou_list = torch.from_numpy(iou_list).to(img.device)

    I_list = np.stack(I_list)
    I_list = torch.from_numpy(I_list).to(img.device)
    U_list = np.stack(U_list)
    U_list = torch.from_numpy(U_list).to(img.device)
    overall_I = I_list.sum().item()
    overall_U = U_list.sum().item()
    overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero

    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (iou_list > thres).float().mean()
        prec_list.append(tmp)
    iou = iou_list.mean()
    prec = {}
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres * 10)
        value = prec_list[i].item()
        prec[key] = value
    logger.info('IoU={:.2f} OIoU={:.4f}'.format(100. * iou.item(), 100. * overall_IoU))
    print('IoU={:.2f} OIoU={:.4f}'.format(100. * iou.item(), 100. * overall_IoU))
    for k, v in prec.items():
        logger.info('{}: {:.2f}.'.format(k, 100. * v))

    return iou.item(), overall_IoU, prec
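Note: the validate/inference routines above report both mean IoU (the per-sample IoU averaged over the split) and overall IoU (intersections and unions summed over the whole split before dividing). A small numpy sketch of the two metrics on toy binary masks; the random masks are made up purely for illustration.

import numpy as np

def miou_and_oiou(preds, gts, eps=1e-6):
    # preds, gts: iterables of boolean masks of the same shape.
    ious, inter_total, union_total = [], 0.0, 0.0
    for pred, gt in zip(preds, gts):
        inter = np.logical_and(pred, gt).sum()
        union = np.logical_or(pred, gt).sum()
        ious.append(inter / (union + eps))   # per-sample IoU for the mean
        inter_total += inter                 # accumulate for overall IoU
        union_total += union
    return float(np.mean(ious)), inter_total / (union_total + eps)

rng = np.random.default_rng(0)
preds = [rng.random((208, 208)) > 0.5 for _ in range(4)]
gts = [rng.random((208, 208)) > 0.5 for _ in range(4)]
print(miou_and_oiou(preds, gts))  # (mean IoU, overall IoU)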