Upload folder using huggingface_hub
Browse files- ASDA/engine/__pycache__/engine.cpython-39.pyc +0 -0
- ASDA/engine/__pycache__/engine_gref_sbert.cpython-39.pyc +0 -0
- ASDA/engine/__pycache__/engine_gref_sbert_oiou.cpython-39.pyc +0 -0
- ASDA/engine/__pycache__/engine_oiou.cpython-39.pyc +0 -0
- ASDA/engine/__pycache__/engine_rcc_sbert.cpython-39.pyc +0 -0
- ASDA/engine/engine.py +167 -0
- ASDA/engine/engine_gref_sbert.py +347 -0
- ASDA/engine/engine_gref_sbert_oiou.py +340 -0
- ASDA/engine/engine_oiou.py +179 -0
- ASDA/engine/engine_rcc_sbert.py +258 -0
- ASDA/engine/tmp.py +292 -0
ASDA/engine/__pycache__/engine.cpython-39.pyc
ADDED
|
Binary file (4.17 kB). View file
|
|
|
ASDA/engine/__pycache__/engine_gref_sbert.cpython-39.pyc
ADDED
|
Binary file (7.52 kB). View file
|
|
|
ASDA/engine/__pycache__/engine_gref_sbert_oiou.cpython-39.pyc
ADDED
|
Binary file (7.52 kB). View file
|
|
|
ASDA/engine/__pycache__/engine_oiou.cpython-39.pyc
ADDED
|
Binary file (4.37 kB). View file
|
|
|
ASDA/engine/__pycache__/engine_rcc_sbert.cpython-39.pyc
ADDED
|
Binary file (6.33 kB). View file
|
|
|
ASDA/engine/engine.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import matplotlib as mpl
|
| 3 |
+
mpl.use('Agg')
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.parallel
|
| 8 |
+
import torch.optim
|
| 9 |
+
from torch.autograd import Variable
|
| 10 |
+
from torch.cuda.amp import autocast as autocast
|
| 11 |
+
|
| 12 |
+
from model.model import *
|
| 13 |
+
from dataset.data_loader import *
|
| 14 |
+
from utils.losses import *
|
| 15 |
+
from utils.parsing_metrics import *
|
| 16 |
+
from utils.utils import *
|
| 17 |
+
from utils.utils import dice_loss, sigmoid_focal_loss
|
| 18 |
+
|
| 19 |
+
use_cuda = torch.cuda.is_available()
|
| 20 |
+
print("use_cuda, ", use_cuda)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    """Run one mixed-precision training epoch for the segmentation model.

    Args:
        rank: CUDA device index / DDP rank tensors are moved to.
        args: namespace read for `seg_thresh` and `print_freq`.
        train_loader: yields (imgs, word_id, word_mask, bbox, seg_map);
            `bbox` is unpacked but never used here.
        model: callable as model(image, word_id, word_mask) -> mask logits.
        optimizer: optimizer stepped through the AMP GradScaler.
        epoch: current epoch number (logging only).
        scaler: torch.cuda.amp.GradScaler driving scale/backward/step/update.
        logger: logging-style object; `.info` is called on rank 0.

    Returns:
        Average loss over the epoch (losses.avg).
    """
    print('train at epoch %d'%epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    # NOTE(review): despite the name, this meter accumulates the training
    # seg IoU (see cal_seg_iou_loss below), not a cosine loss — confirm.
    cos_losses = AverageMeter()
    model.train()
    end = time.time()

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map) in enumerate(train_loader):
        imgs = imgs.cuda(rank, non_blocking=True)
        word_id = word_id.cuda(rank, non_blocking=True)
        word_mask = word_mask.cuda(rank, non_blocking=True)
        seg_map = seg_map.cuda(rank, non_blocking=True)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        # Forward pass and loss computation under autocast (fp16/bf16).
        with autocast():
            mask_out = model(image, word_id, word_mask)
            loss = 0.

            # Detached numpy copies, used only for the IoU training metric.
            mask_out_np = mask_out.data.cpu().numpy() # [bs, 1, 208, 208]
            seg_map_np = seg_map.cpu().numpy() # [bs, 1, 208, 208]
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map)

            # Total loss = dice + sigmoid focal (equal weights).
            loss += dice_loss_ + sigmoid_focal_loss_

        # Standard GradScaler protocol: zero -> scaled backward -> step -> update.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), imgs.size(0))
        dice_losses.update(dice_loss_.item(), imgs.size(0))
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), imgs.size(0))
        cos_losses.update(seg_iou.mean().item(), imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Periodic progress logging, on rank 0 only.
        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg
def validate_epoch(args, val_loader, model, logger, mode='val'):
    """Evaluate the model on `val_loader` (batch size 1 assumed).

    Undoes the letterbox preprocessing (scale `ratio`, pads `dw`/`dh`) to
    score predictions against ground truth, accumulating mean seg IoU and
    precision@thresh for thresholds 0.5..0.95.

    Args:
        args: namespace read for `size` (network input side) and `seg_thresh`.
        val_loader: yields (imgs, word_id, word_mask, bbox, seg_map, ratio,
            dw, dh, im_id, phrase, draw_img); bbox/im_id/phrase/draw_img are
            unpacked but unused here.
        model: callable as model(image, word_id, word_mask) -> mask logits.
        logger: logging-style object.
        mode: unused; kept for interface compatibility.

    Returns:
        (miou_seg.avg, prec) — mean IoU and dict of precision AverageMeters.
    """
    print('begin test')
    batch_time = AverageMeter()
    # NOTE(review): `miou` is never updated below, so the final log line
    # always reports 0 for it — looks like a leftover; confirm.
    miou = AverageMeter()
    miou_seg = AverageMeter()

    prec=dict()
    thresholds = np.arange(0.5, 1, 0.05)

    for thresh in thresholds:
        prec[thresh]= AverageMeter()

    model.eval()
    end = time.time()
    idx = 0

    t_all = []  # per-sample forward-pass latencies

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):

        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        t1 = time.time()
        with torch.no_grad():
            mask_out = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()  # logits -> probabilities

        t2 = time.time()
        t_all.append(t2-t1)

        ## test: convert pred, gt box to original scale with meta-info
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        # Crop window that removes the letterbox padding.
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)  # cv2.resize takes (width, height)

        ## revert image for visualization
        seg_map_np = seg_map[0,:,:,:].data.cpu().numpy().transpose(1,2,0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = imgs[0,:,top:bottom,left:right].data.cpu().numpy().transpose(1,2,0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)

        img_np = Variable(torch.from_numpy(img_np.transpose(2,0,1)).cuda().unsqueeze(0))

        # seg: upsample prediction to input size, strip padding, resize back.
        mask_out = mask_out[0].data.cpu().numpy().transpose(1,2,0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)
        # NOTE(review): the un-resized seg_map tensor is compared against the
        # de-letterboxed prediction here (seg_map_np above is unused for
        # scoring) — confirm both are in the same coordinate frame.
        seg_iou, seg_prec = cal_seg_iou(seg_map[0].cpu().numpy(), mask_out_np, args.seg_thresh)
        miou_seg.update(seg_iou, imgs.size(0))
        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                .format( \
                batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1

    print(miou_seg.avg)
    for thresh in thresholds:
        print("prec@%f: %f"%(thresh,float(prec[thresh].avg)))
        logger.info("prec@%f:%f"%(thresh,float(prec[thresh].avg)))
    logger.info("%f,%f"%(float(miou.avg), miou_seg.avg))
    return miou_seg.avg, prec
ASDA/engine/engine_gref_sbert.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import matplotlib as mpl
|
| 3 |
+
mpl.use('Agg')
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.parallel
|
| 8 |
+
import torch.optim
|
| 9 |
+
from torch.autograd import Variable
|
| 10 |
+
from torch.cuda.amp import autocast as autocast
|
| 11 |
+
|
| 12 |
+
from model.model_sbert_gref import *
|
| 13 |
+
from dataset.data_loader import *
|
| 14 |
+
from utils.losses import *
|
| 15 |
+
from utils.parsing_metrics import *
|
| 16 |
+
from utils.utils import *
|
| 17 |
+
from utils.utils import dice_loss, sigmoid_focal_loss
|
| 18 |
+
|
| 19 |
+
use_cuda = torch.cuda.is_available()
|
| 20 |
+
print("use_cuda, ", use_cuda)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def return_mask(emb_distance, verb_mask=None, rows_to_filter=None, cols_to_filter=None):
    """Build positive/negative pair masks for the angular contrastive loss.

    Args:
        emb_distance: square (B_, B_) similarity matrix; only its shape,
            dtype and device are used to size the masks.
        verb_mask: per-sample flags (1 = sample has a hard-positive verb
            phrase). Its length relative to B_ selects the pairing scheme.
        rows_to_filter, cols_to_filter: optional index pairs of near-duplicate
            hard positives; their 2x2 blocks are removed from the negatives.

    Returns:
        (positive_mask, negative_mask): {0,1} tensors shaped like
        `emb_distance`. Diagonal entries are positive; adjacent
        (original, hard-positive) rows are mutual positives; everything
        else is negative unless filtered out.
    """
    B_, _ = emb_distance.shape
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)  # each sample is its own positive

    if B_ < len(verb_mask):
        # Every row comes in an (original, hard-positive) adjacent pair:
        # B_ == 2 * K where K is the number of verb phrases.
        for i in range(B_ // 2):
            positive_mask[2 * i, 2 * i + 1] = 1
            positive_mask[2 * i + 1, 2 * i] = 1
    else:
        # Mixed batch: only rows flagged in verb_mask have a paired next row.
        i = 0
        while i < B_:
            if verb_mask[i] == 1:
                positive_mask[i, i + 1] = 1
                positive_mask[i + 1, i] = 1
                i += 2
            else:
                i += 1

    # Everything that is not a positive is a candidate negative. The
    # subtraction already allocates a fresh tensor, so the redundant
    # `.clone()` the original code performed here has been dropped.
    negative_mask = torch.ones_like(emb_distance) - positive_mask

    if rows_to_filter is not None and cols_to_filter is not None:
        # Zero out the 2x2 block of each filtered (row, col) pair so that
        # near-duplicate hard positives are not treated as negatives.
        for row, col in zip(rows_to_filter, cols_to_filter):
            negative_mask[row * 2, col * 2] = 0
            negative_mask[row * 2, col * 2 + 1] = 0
            negative_mask[row * 2 + 1, col * 2] = 0
            negative_mask[row * 2 + 1, col * 2 + 1] = 0

    return positive_mask, negative_mask
|
| 55 |
+
|
| 56 |
+
def UniAngularLogitContrastLoss(total_fq, verb_mask, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Angular-margin InfoNCE-style loss over (original, hard-positive) pairs.

    Spatially pools the feature maps, computes pairwise cosine similarity,
    applies an additive angular margin `m` (degrees) to positive pairs, and
    returns a softmax contrastive loss with temperature `tau`.

    Args:
        total_fq: (B, C, H, W) feature maps; H and W are only consumed by the
            shape unpack.
        verb_mask: boolean/int mask selecting rows that belong to
            original/hard-positive pairs (forwarded to return_mask).
        rows_to_filter, cols_to_filter: near-duplicate pair indices excluded
            from negatives (forwarded to return_mask).
        alpha: unused in this implementation (kept for interface
            compatibility).
        verbonly: if True, pool only the verb-masked rows; the selected count
            must be even (asserted) since rows pair up adjacently.
        m: angular margin in degrees, converted to radians below.
        tau: softmax temperature.
        args: unused in this implementation (kept for interface
            compatibility).

    Returns:
        (angular_loss, B_): scalar loss and the number of embeddings used.
    """
    _, C, H, W = total_fq.shape

    # Calculate embeddings by global average pooling over the spatial dims.
    if verbonly :
        B = total_fq[verb_mask].shape[0]
        emb = torch.mean(total_fq[verb_mask], dim=(-1, -2)).reshape(B, C)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else :
        # NOTE(review): this branch pools only the last dim, leaving a
        # (B, C, H) tensor — inconsistent with the verbonly branch; confirm
        # it is ever exercised.
        emb = torch.mean(total_fq, dim=-1)

    B_ = emb.shape[0]
    # Broadcast embeddings into all (i, j) pairs.
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (B_, B_, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (B_, B_, C)

    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (B_, B_)
    # Clamp away from ±1 so acos stays differentiable.
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    margin_in_radians = m / 57.2958 # Convert degrees to radians
    # NOTE(review): pi/2 - acos(sim) equals asin(sim), i.e. the complement of
    # the pair angle, so "theta" here grows with similarity — confirm this is
    # the intended logit direction.
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)

    positive_mask, negative_mask = return_mask(sim_matrix, verb_mask, rows_to_filter, cols_to_filter)

    # Apply the additive angular margin to positive pairs only.
    theta_with_margin = theta_matrix.clone()
    theta_with_margin[positive_mask.bool()] -= margin_in_radians
    logits = theta_with_margin / tau # Scale with temperature

    # Compute exp logits for softmax over positives vs. negatives.
    exp_logits = torch.exp(logits)
    pos_exp_logits = exp_logits * positive_mask
    pos_exp_logits = pos_exp_logits.sum(dim=-1)
    neg_exp_logits = exp_logits * negative_mask
    neg_exp_logits = neg_exp_logits.sum(dim=-1)

    total_exp_logits = pos_exp_logits + neg_exp_logits

    # Per-row NCE term, averaged over the batch.
    positive_loss = -torch.log(pos_exp_logits/ total_exp_logits)
    angular_loss = positive_loss.mean()

    return angular_loss, B_
| 108 |
+
|
| 109 |
+
def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    """Run one mixed-precision training epoch with hard-positive augmentation.

    Samples whose `pos_type` is 'hardpos' are duplicated with their
    hard-positive verb phrase; the pairs feed an angular contrastive loss on
    top of the dice + sigmoid-focal segmentation loss.

    Args:
        rank: CUDA device index / DDP rank tensors are moved to.
        args: read for metric_loss_weight, metric_mode, filter_thres,
            metric_learning, seg_thresh, margin_value, temperature, print_freq.
        train_loader: yields (imgs, word_id, word_mask, bbox, seg_map, params)
            where params carries 'hp_word_id', 'hp_word_mask', 'hardpos_emb',
            'pos_type'; `bbox` is unpacked but unused.
        model: callable as model(image, word_id, word_mask)
            -> (mask logits, metric feature tensors).
        optimizer, scaler: optimizer driven through the AMP GradScaler.
        epoch: current epoch number (logging only).
        logger: logging-style object; `.info` is called on rank 0.

    Returns:
        Average loss over the epoch (losses.avg).
    """
    print('train at epoch %d'%epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    # NOTE(review): despite the name, this meter accumulates the training
    # seg IoU (see cal_seg_iou_loss below), not a cosine loss — confirm.
    cos_losses = AverageMeter()
    model.train()
    end = time.time()

    # arguments for verb-centric radial contrastive loss
    mlw = args.metric_loss_weight
    metric_mode = args.metric_mode
    filter_thres = args.filter_thres
    metric_learning = args.metric_learning

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, params) in enumerate(train_loader):
        B = imgs.size(0) # Original Batch size

        hp_word_id = params['hp_word_id']
        hp_word_mask = params['hp_word_mask']
        hp_bert_embs = params['hardpos_emb'].cuda(non_blocking=True).squeeze(1)
        pos_type = np.array(params['pos_type'])

        # 1 where a hard-positive verb phrase exists for the sample.
        pos_mask = torch.tensor(np.where(pos_type == 'hardpos', 1, 0))

        # Expand the batch: each hardpos sample contributes a second row
        # (same image/target, hard-positive phrase) adjacent to the original.
        verb_masks = []   # 1 on rows that belong to an orig/verb pair
        cl_masks = []     # 1 on rows used for the segmentation losses
        images = []
        targets = []
        sentences_ = []
        sentences_masked_ = []

        for idx in range(len(imgs)) :
            sentences_.append(word_id[idx])
            sentences_masked_.append(word_mask[idx])
            images.append(imgs[idx])
            targets.append(seg_map[idx])

            # If verb exists, process it
            if pos_mask[idx] :
                verb_masks.extend([1, 1]) # Both original sentence and verb are marked
                cl_masks.extend([1, 0])   # Only original sentence gets marked
                sentences_.append(hp_word_id[idx])
                sentences_masked_.append(hp_word_mask[idx])
                images.append(imgs[idx])
                targets.append(seg_map[idx])
            else:
                verb_masks.append(0)
                cl_masks.append(1)

        imgs, seg_map, word_id, word_mask, verb_masks, cl_masks = \
            torch.stack(images).cuda(rank, non_blocking=True),\
            torch.stack(targets).cuda(rank, non_blocking=True),\
            torch.stack(sentences_).cuda(rank, non_blocking=True),\
            torch.stack(sentences_masked_).cuda(rank, non_blocking=True),\
            torch.tensor(verb_masks, dtype=torch.bool).cuda(rank, non_blocking=True),\
            torch.tensor(cl_masks, dtype=torch.bool).cuda(rank, non_blocking=True)

        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)
        verb_masks = Variable(verb_masks)
        cl_masks = Variable(cl_masks)

        # FIX: reset the duplicate-filter indices every iteration. Previously
        # they were only assigned inside the numel() > 0 branch, so an empty
        # hardpos-embedding batch would either raise NameError (first batch)
        # or silently reuse stale indices from the previous iteration.
        # return_mask treats None as "no filtering".
        rows_to_filter, cols_to_filter = None, None

        if hp_bert_embs.numel() > 0 :
            # Drop all-zero (padding) embeddings, then find near-duplicate
            # hard positives by cosine similarity above filter_thres.
            mask = ~torch.all(hp_bert_embs == 0, dim=1)
            hp_bert_embs = hp_bert_embs[mask]
            norms = torch.norm(hp_bert_embs, dim=-1, keepdim=True)
            normed_embs = hp_bert_embs / norms
            cosime_sim = torch.mm(normed_embs, normed_embs.T)
            rows_to_filter, cols_to_filter = torch.where(cosime_sim > filter_thres)

        with autocast():
            mask_out_all, metric_tensors = model(image, word_id, word_mask)
            loss = 0.

            # get mask and seg_map for calculating existing loss function
            # (iou loss, dice loss, sigmoid focal loss)
            mask_out = mask_out_all[cl_masks]
            seg_map_cl = seg_map[cl_masks]

            mask_out_np = mask_out.data.cpu().numpy() # [bs, 1, 208, 208]
            seg_map_np = seg_map_cl.cpu().numpy() # [bs, 1, 208, 208]
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map_cl)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map_cl)

            dice_weight, focal_weight = 1.0, 1.0
            loss = (dice_weight * dice_loss_) + (focal_weight * sigmoid_focal_loss_)

            # Angular contrastive loss over original/verb-phrase pairs (only
            # when more than one hardpos pair exists in the batch).
            if metric_learning and sum(pos_mask) > 1 :
                metric_weight = mlw
                # NS means number of orig-verb pairs where a verb phrase exists.
                metric_loss, NS = UniAngularLogitContrastLoss(metric_tensors, verb_masks, rows_to_filter, cols_to_filter, m=args.margin_value, tau=args.temperature, verbonly=True, args=args)
                loss += metric_weight * metric_loss

        # Standard GradScaler protocol: zero -> scaled backward -> step -> update.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Meters are weighted by the original (pre-expansion) batch size B.
        losses.update(loss.item(), B)
        dice_losses.update(dice_loss_.item(), B)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), B)
        cos_losses.update(seg_iou.mean().item(), B)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def validate_epoch(args, val_loader, model, logger, mode='val'):
    """Evaluate the model on `val_loader` (batch size 1 assumed).

    Undoes the letterbox preprocessing (scale `ratio`, pads `dw`/`dh`) before
    scoring, accumulating per-sample mean IoU, dataset-level overall IoU
    (summed intersection / summed union), and precision@0.5..0.95.

    Args:
        args: namespace read for `size` (network input side) and `seg_thresh`.
        val_loader: yields (imgs, word_id, word_mask, bbox, seg_map, ratio,
            dw, dh, im_id, phrase, draw_img); bbox/im_id/phrase/draw_img are
            unpacked but unused here.
        model: callable as model(image, word_id, word_mask)
            -> (mask logits, metric tensors); metric tensors are discarded.
        logger: logging-style object.
        mode: unused; kept for interface compatibility.

    Returns:
        (miou_seg.avg, prec) — mean IoU and dict of precision AverageMeters.
    """
    print('begin test')
    batch_time = AverageMeter()
    # NOTE(review): `miou` is never updated in this function — leftover meter.
    miou = AverageMeter()
    miou_seg = AverageMeter()

    prec=dict()
    thresholds = np.arange(0.5, 1, 0.05)

    for thresh in thresholds:
        prec[thresh]= AverageMeter()

    model.eval()
    end = time.time()
    idx = 0

    t_all = []  # per-sample forward-pass latencies
    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):

        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        t1 = time.time()
        with torch.no_grad():
            mask_out, _ = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()  # logits -> probabilities

        t2 = time.time()
        t_all.append(t2-t1)

        ## test: convert pred, gt box to original scale with meta-info
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        # Crop window that removes the letterbox padding.
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)  # cv2.resize takes (width, height)

        ## revert image for visualization
        seg_map_np = seg_map[0,:,:,:].data.cpu().numpy().transpose(1,2,0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = imgs[0,:,top:bottom,left:right].data.cpu().numpy().transpose(1,2,0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)

        img_np = Variable(torch.from_numpy(img_np.transpose(2,0,1)).cuda().unsqueeze(0))

        # seg: upsample prediction to input size, strip padding, resize back.
        mask_out = mask_out[0].data.cpu().numpy().transpose(1,2,0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)
        # seg_iou, seg_prec = cal_seg_iou(seg_map[0].cpu().numpy(), mask_out_np, args.seg_thresh)
        # cal_seg_iou2 additionally returns intersection/union sums so that an
        # overall (dataset-pooled) IoU can be computed after the loop.
        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum

        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                .format( \
                batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1

    # Epsilon guards against an empty loader / all-empty masks.
    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f"%(thresh,float(prec[thresh].avg)))
        logger.info("prec@%f:%f"%(thresh,float(prec[thresh].avg)))

    return miou_seg.avg, prec
+
|
| 347 |
+
|
ASDA/engine/engine_gref_sbert_oiou.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import matplotlib as mpl
|
| 3 |
+
mpl.use('Agg')
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.parallel
|
| 8 |
+
import torch.optim
|
| 9 |
+
from torch.autograd import Variable
|
| 10 |
+
from torch.cuda.amp import autocast as autocast
|
| 11 |
+
|
| 12 |
+
from model.model_sbert_gref import *
|
| 13 |
+
from dataset.data_loader import *
|
| 14 |
+
from utils.losses import *
|
| 15 |
+
from utils.parsing_metrics import *
|
| 16 |
+
from utils.utils import *
|
| 17 |
+
from utils.utils import dice_loss, sigmoid_focal_loss
|
| 18 |
+
|
| 19 |
+
use_cuda = torch.cuda.is_available()
|
| 20 |
+
print("use_cuda, ", use_cuda)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def return_mask(emb_distance, verb_mask=None, rows_to_filter=None, cols_to_filter=None):
    """Build positive/negative pair masks for the angular contrastive loss.

    Args:
        emb_distance: square (B_, B_) similarity matrix; only its shape,
            dtype and device are used to size the masks.
        verb_mask: per-sample flags (1 = sample has a hard-positive verb
            phrase). Its length relative to B_ selects the pairing scheme.
        rows_to_filter, cols_to_filter: optional index pairs of near-duplicate
            hard positives; their 2x2 blocks are removed from the negatives.

    Returns:
        (positive_mask, negative_mask): {0,1} tensors shaped like
        `emb_distance`. Diagonal entries are positive; adjacent
        (original, hard-positive) rows are mutual positives; everything
        else is negative unless filtered out.
    """
    B_, _ = emb_distance.shape
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)  # each sample is its own positive

    if B_ < len(verb_mask):
        # Every row comes in an (original, hard-positive) adjacent pair:
        # B_ == 2 * K where K is the number of verb phrases.
        for i in range(B_ // 2):
            positive_mask[2 * i, 2 * i + 1] = 1
            positive_mask[2 * i + 1, 2 * i] = 1
    else:
        # Mixed batch: only rows flagged in verb_mask have a paired next row.
        i = 0
        while i < B_:
            if verb_mask[i] == 1:
                positive_mask[i, i + 1] = 1
                positive_mask[i + 1, i] = 1
                i += 2
            else:
                i += 1

    # Everything that is not a positive is a candidate negative. The
    # subtraction already allocates a fresh tensor, so the redundant
    # `.clone()` the original code performed here has been dropped.
    negative_mask = torch.ones_like(emb_distance) - positive_mask

    if rows_to_filter is not None and cols_to_filter is not None:
        # Zero out the 2x2 block of each filtered (row, col) pair so that
        # near-duplicate hard positives are not treated as negatives.
        for row, col in zip(rows_to_filter, cols_to_filter):
            negative_mask[row * 2, col * 2] = 0
            negative_mask[row * 2, col * 2 + 1] = 0
            negative_mask[row * 2 + 1, col * 2] = 0
            negative_mask[row * 2 + 1, col * 2 + 1] = 0

    return positive_mask, negative_mask
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def UniAngularLogitContrastLoss(total_fq, verb_mask, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Angular-margin (ArcFace-style) contrastive loss over pooled features.

    Args:
        total_fq: (B, C, H, W) feature maps for the interleaved
            original/verb-phrase batch.
        verb_mask: boolean mask selecting rows that belong to a verb pair.
        rows_to_filter, cols_to_filter: near-duplicate pair indices removed
            from the negatives (forwarded to return_mask); may be None.
        alpha: unused; kept for interface compatibility.
        verbonly: if True (the only mode exercised by callers), the loss is
            computed on verb-pair rows only.
        m: additive angular margin in degrees.
        tau: softmax temperature.
        args: unused; kept for interface compatibility.

    Returns:
        (angular_loss, B_) where B_ is the number of embeddings involved.
    """
    _, C, H, W = total_fq.shape

    # Global-average-pool each selected feature map into a C-dim embedding.
    if verbonly:
        B = total_fq[verb_mask].shape[0]
        emb = torch.mean(total_fq[verb_mask], dim=(-1, -2)).reshape(B, C)
        # Rows must come in (original, verb) pairs.
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        # NOTE(review): this pools the width axis only, leaving (B, C, H);
        # callers always pass verbonly=True, so this branch looks unfinished
        # — confirm before relying on it.
        emb = torch.mean(total_fq, dim=-1)

    B_ = emb.shape[0]
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)

    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (B_, B_)
    # Clamp so acos stays finite and differentiable at +/-1.
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    # Exact degree -> radian conversion (was a truncated 1/57.2958 factor).
    margin_in_radians = m * torch.pi / 180.0
    # theta = pi/2 - acos(sim) = asin(sim): grows as embeddings align.
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)

    positive_mask, negative_mask = return_mask(sim_matrix, verb_mask, rows_to_filter, cols_to_filter)

    # Subtract the angular margin from positive pairs before the softmax,
    # which tightens the decision boundary around positives.
    theta_with_margin = theta_matrix.clone()
    theta_with_margin[positive_mask.bool()] -= margin_in_radians
    logits = theta_with_margin / tau  # scale with temperature

    # InfoNCE-style objective: maximize the positive share of softmax mass.
    exp_logits = torch.exp(logits)
    pos_exp_logits = (exp_logits * positive_mask).sum(dim=-1)
    neg_exp_logits = (exp_logits * negative_mask).sum(dim=-1)
    total_exp_logits = pos_exp_logits + neg_exp_logits

    positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
    angular_loss = positive_loss.mean()

    return angular_loss, B_
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    """Train one epoch: dice + sigmoid-focal segmentation losses, plus an
    optional verb-centric angular contrastive loss over hard-positive pairs.

    Returns the epoch-average total loss.
    """
    print('train at epoch %d'%epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    cos_losses = AverageMeter()
    model.train()
    end = time.time()

    # arguments for verb-centric radial contrastive loss
    mlw = args.metric_loss_weight
    metric_mode = args.metric_mode  # currently unused in this engine
    filter_thres = args.filter_thres
    metric_learning = args.metric_learning

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, params) in enumerate(train_loader):
        B = imgs.size(0)  # original batch size (before pair expansion)

        hp_word_id = params['hp_word_id']
        hp_word_mask = params['hp_word_mask']
        hp_bert_embs = params['hardpos_emb'].cuda(non_blocking=True).squeeze(1)
        pos_type = np.array(params['pos_type'])

        # 1 where the sample carries a hard-positive verb phrase.
        pos_mask = torch.tensor(np.where(pos_type == 'hardpos', 1, 0))

        # Interleave each original sample with its verb-phrase twin so the
        # expanded batch is (orig_0, verb_0, orig_1, ...) for paired rows.
        verb_masks = []
        cl_masks = []
        images = []
        targets = []
        sentences_ = []
        sentences_masked_ = []

        for idx in range(len(imgs)):
            sentences_.append(word_id[idx])
            sentences_masked_.append(word_mask[idx])
            images.append(imgs[idx])
            targets.append(seg_map[idx])

            if pos_mask[idx]:
                verb_masks.extend([1, 1])  # both rows join the metric loss
                cl_masks.extend([1, 0])    # only the original joins seg losses
                sentences_.append(hp_word_id[idx])
                sentences_masked_.append(hp_word_mask[idx])
                images.append(imgs[idx])
                targets.append(seg_map[idx])
            else:
                verb_masks.append(0)
                cl_masks.append(1)

        imgs, seg_map, word_id, word_mask, verb_masks, cl_masks = \
            torch.stack(images).cuda(rank, non_blocking=True),\
            torch.stack(targets).cuda(rank, non_blocking=True),\
            torch.stack(sentences_).cuda(rank, non_blocking=True),\
            torch.stack(sentences_masked_).cuda(rank, non_blocking=True),\
            torch.tensor(verb_masks, dtype=torch.bool).cuda(rank, non_blocking=True),\
            torch.tensor(cl_masks, dtype=torch.bool).cuda(rank, non_blocking=True)

        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)
        verb_masks = Variable(verb_masks)
        cl_masks = Variable(cl_masks)

        # FIX: bind the filters up front so the metric-loss branch below can
        # never hit a NameError when hp_bert_embs is empty for this batch.
        rows_to_filter, cols_to_filter = None, None
        if hp_bert_embs.numel() > 0:
            # Drop all-zero (padding) embeddings before similarity.
            valid = ~torch.all(hp_bert_embs == 0, dim=1)
            hp_bert_embs = hp_bert_embs[valid]
            norms = torch.norm(hp_bert_embs, dim=-1, keepdim=True)
            normed_embs = hp_bert_embs / norms
            cosine_sim = torch.mm(normed_embs, normed_embs.T)
            # Near-duplicate phrases must not be treated as negatives.
            rows_to_filter, cols_to_filter = torch.where(cosine_sim > filter_thres)

        with autocast():
            mask_out_all, metric_tensors = model(image, word_id, word_mask)

            # Segmentation losses only over rows flagged by cl_masks.
            mask_out = mask_out_all[cl_masks]
            seg_map_cl = seg_map[cl_masks]

            mask_out_np = mask_out.data.cpu().numpy()  # [bs, 1, 208, 208]
            seg_map_np = seg_map_cl.cpu().numpy()      # [bs, 1, 208, 208]
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map_cl)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map_cl)

            dice_weight, focal_weight = 1.0, 1.0
            loss = (dice_weight * dice_loss_) + (focal_weight * sigmoid_focal_loss_)

            # Angular contrastive loss over original/verb-phrase pairs;
            # needs at least two pairs to have any negatives.
            if metric_learning and sum(pos_mask) > 1:
                # NS is the number of embeddings involved in the metric loss.
                metric_loss, NS = UniAngularLogitContrastLoss(metric_tensors, verb_masks, rows_to_filter, cols_to_filter, m=args.margin_value, tau=args.temperature, verbonly=True, args=args)
                loss += mlw * metric_loss

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), B)
        dice_losses.update(dice_loss_.item(), B)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), B)
        cos_losses.update(seg_iou.mean().item(), B)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg
|
| 244 |
+
|
| 245 |
+
def validate_epoch(args, val_loader, model, logger, mode='val'):
    """Evaluate the model on val_loader.

    Returns (mean IoU, overall IoU, precision@threshold dict).
    """
    print('begin test')
    batch_time = AverageMeter()
    miou = AverageMeter()      # kept for parity with other engines; unused
    miou_seg = AverageMeter()

    thresholds = np.arange(0.5, 1, 0.05)
    prec = {thresh: AverageMeter() for thresh in thresholds}

    model.eval()
    end = time.time()
    idx = 0

    t_all = []                 # per-sample forward-pass times
    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):
        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        image, word_id, word_mask, seg_map = (
            Variable(imgs), Variable(word_id), Variable(word_mask), Variable(seg_map))

        tic = time.time()
        with torch.no_grad():
            mask_out, _ = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()
        t_all.append(time.time() - tic)

        # Undo letterboxing: map prediction back to the original image scale.
        ih, iw = seg_map.shape[-2], seg_map.shape[-1]
        nh, nw = int(ih * ratio), int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)

        # Ground truth and input image restored to original scale.
        seg_map_np = seg_map[0, :, :, :].data.cpu().numpy().transpose(1, 2, 0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = imgs[0, :, top:bottom, left:right].data.cpu().numpy().transpose(1, 2, 0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = Variable(torch.from_numpy(img_np.transpose(2, 0, 1)).cuda().unsqueeze(0))

        # Crop the padded border from the prediction, then rescale it.
        mask_out = mask_out[0].data.cpu().numpy().transpose(1, 2, 0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = cv2.resize(mask_out[top:bottom, left:right], new_shape)

        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum
        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                .format( \
                batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1

    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f"%(thresh,float(prec[thresh].avg)))
        logger.info("prec@%f:%f"%(thresh,float(prec[thresh].avg)))
    return miou_seg.avg, overall_iou, prec
|
ASDA/engine/engine_oiou.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import matplotlib as mpl
|
| 3 |
+
mpl.use('Agg')
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.parallel
|
| 8 |
+
import torch.optim
|
| 9 |
+
from torch.autograd import Variable
|
| 10 |
+
from torch.cuda.amp import autocast as autocast
|
| 11 |
+
|
| 12 |
+
from model.model import *
|
| 13 |
+
from dataset.data_loader import *
|
| 14 |
+
from utils.losses import *
|
| 15 |
+
from utils.parsing_metrics import *
|
| 16 |
+
from utils.utils import *
|
| 17 |
+
from utils.utils import dice_loss, sigmoid_focal_loss
|
| 18 |
+
|
| 19 |
+
use_cuda = torch.cuda.is_available()
|
| 20 |
+
print("use_cuda, ", use_cuda)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    """Run one training epoch with dice + sigmoid-focal segmentation losses.

    Returns the epoch-average total loss.
    """
    print('train at epoch %d'%epoch)

    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    cos_losses = AverageMeter()

    model.train()
    end = time.time()

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map) in enumerate(train_loader):
        # Move the batch to this rank's device.
        imgs = imgs.cuda(rank, non_blocking=True)
        word_id = word_id.cuda(rank, non_blocking=True)
        word_mask = word_mask.cuda(rank, non_blocking=True)
        seg_map = seg_map.cuda(rank, non_blocking=True)
        image, word_id, word_mask, seg_map = (
            Variable(imgs), Variable(word_id), Variable(word_mask), Variable(seg_map))

        with autocast():
            mask_out = model(image, word_id, word_mask)

            # Training-time IoU, computed on detached numpy copies.
            seg_iou = cal_seg_iou_loss(seg_map.cpu().numpy(),
                                       mask_out.data.cpu().numpy(),
                                       args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map)
            loss = dice_loss_ + sigmoid_focal_loss_

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        n = imgs.size(0)
        losses.update(loss.item(), n)
        dice_losses.update(dice_loss_.item(), n)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), n)
        cos_losses.update(seg_iou.mean().item(), n)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg
|
| 82 |
+
|
| 83 |
+
def validate_epoch(args, val_loader, model, logger, mode='val'):
    """Evaluate the model; return (mean IoU, overall IoU, precision dict)."""
    print('begin test')
    batch_time = AverageMeter()
    miou = AverageMeter()       # unused, kept for parity with other engines
    miou_seg = AverageMeter()

    thresholds = np.arange(0.5, 1, 0.05)
    prec = {t: AverageMeter() for t in thresholds}

    model.eval()
    end = time.time()
    idx = 0

    t_all = []                  # per-sample inference times
    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):
        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        image = Variable(imgs)
        word_id, word_mask, seg_map = Variable(word_id), Variable(word_mask), Variable(seg_map)

        t_start = time.time()
        with torch.no_grad():
            mask_out = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()
        t_all.append(time.time() - t_start)

        # Map prediction and ground truth back to the pre-letterbox scale.
        ih, iw = seg_map.shape[-2], seg_map.shape[-1]
        nh, nw = int(ih * ratio), int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)

        seg_map_np = seg_map[0, :, :, :].data.cpu().numpy().transpose(1, 2, 0)
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = imgs[0, :, top:bottom, left:right].data.cpu().numpy().transpose(1, 2, 0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = Variable(torch.from_numpy(img_np.transpose(2, 0, 1)).cuda().unsqueeze(0))

        # Crop padding off the predicted mask, then rescale to the original.
        mask_out = mask_out[0].data.cpu().numpy().transpose(1, 2, 0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = cv2.resize(mask_out[top:bottom, left:right], new_shape)

        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum
        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                .format( \
                batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1

    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f"%(thresh,float(prec[thresh].avg)))
        logger.info("prec@%f:%f"%(thresh,float(prec[thresh].avg)))
    return miou_seg.avg, overall_iou, prec
|
| 179 |
+
|
ASDA/engine/engine_rcc_sbert.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import matplotlib as mpl
|
| 3 |
+
mpl.use('Agg')
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.parallel
|
| 8 |
+
import torch.optim
|
| 9 |
+
from torch.autograd import Variable
|
| 10 |
+
from torch.cuda.amp import autocast as autocast
|
| 11 |
+
|
| 12 |
+
from model.model_sbert_gref import *
|
| 13 |
+
from dataset.data_loader import *
|
| 14 |
+
from utils.losses import *
|
| 15 |
+
from utils.parsing_metrics import *
|
| 16 |
+
from utils.utils import *
|
| 17 |
+
from utils.utils import dice_loss, sigmoid_focal_loss
|
| 18 |
+
|
| 19 |
+
use_cuda = torch.cuda.is_available()
|
| 20 |
+
print("use_cuda, ", use_cuda)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def return_mask(emb_distance, rows_to_filter=None, cols_to_filter=None):
    """Build positive/negative masks: diagonal = positive, rest = negative.

    Args:
        emb_distance: (B_, B_) square similarity matrix; only its shape,
            dtype and device are used.
        rows_to_filter, cols_to_filter: parallel index sequences (e.g. from
            torch.where) of near-duplicate pairs to drop from the negatives.

    Returns:
        (positive_mask, negative_mask), both (B_, B_) tensors.
    """
    B_ = emb_distance.shape[0]
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)  # each embedding is its own positive
    negative_mask = torch.ones_like(emb_distance) - positive_mask

    if rows_to_filter is not None and cols_to_filter is not None:
        # Vectorized advanced indexing replaces the per-pair Python loop;
        # near-duplicate entries are excluded from the negative set.
        negative_mask[rows_to_filter, cols_to_filter] = 0
    return positive_mask, negative_mask
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def UniAngularLogitContrastLoss(total_fq, rows_to_filter, cols_to_filter, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Angular-margin contrastive loss with the diagonal as positives.

    Args:
        total_fq: (B, C, H, W) feature maps; each is global-average-pooled
            into a C-dim embedding.
        rows_to_filter, cols_to_filter: near-duplicate index pairs removed
            from the negative set (forwarded to return_mask); may be None.
        alpha, verbonly, args: unused; kept for interface compatibility
            with the paired-engine variant of this loss.
        m: additive angular margin in degrees.
        tau: softmax temperature.

    Returns:
        Scalar angular contrastive loss.
    """
    _, C, H, W = total_fq.shape

    # Global-average-pool each feature map to a C-dim embedding.
    B = total_fq.shape[0]
    emb = torch.mean(total_fq, dim=(-1, -2)).reshape(B, C)

    B_ = emb.shape[0]
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)

    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (B_, B_)
    # Clamp so acos stays finite and differentiable at +/-1.
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

    # Exact degree -> radian conversion (was a truncated 1/57.2958 factor).
    margin_in_radians = m * torch.pi / 180.0
    # theta = pi/2 - acos(sim) = asin(sim): grows as embeddings align.
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)

    positive_mask, negative_mask = return_mask(sim_matrix, rows_to_filter, cols_to_filter)

    # Subtract the angular margin from positives before the softmax.
    theta_with_margin = theta_matrix.clone()
    theta_with_margin[positive_mask.bool()] -= margin_in_radians
    logits = theta_with_margin / tau  # scale with temperature

    # InfoNCE-style objective: maximize the positive share of softmax mass.
    exp_logits = torch.exp(logits)
    pos_exp_logits = (exp_logits * positive_mask).sum(dim=-1)
    neg_exp_logits = (exp_logits * negative_mask).sum(dim=-1)
    total_exp_logits = pos_exp_logits + neg_exp_logits

    positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
    angular_loss = positive_loss.mean()

    return angular_loss
|
| 74 |
+
|
| 75 |
+
def train_epoch(rank, args, train_loader, model, optimizer, epoch, scaler, logger):
    """Train one epoch: dice + sigmoid-focal segmentation losses plus an
    optional angular contrastive loss over the batch embeddings.

    Returns the epoch-average total loss.
    """
    print('train at epoch %d'%epoch)
    batch_time = AverageMeter()
    losses = AverageMeter()
    dice_losses = AverageMeter()
    sigmoid_focal_losses = AverageMeter()
    cos_losses = AverageMeter()
    model.train()
    end = time.time()

    # argument for verb-centric radial contrastive loss
    mlw = args.metric_loss_weight
    metric_mode = args.metric_mode  # currently unused in this engine
    filter_thres = args.filter_thres
    metric_learning = args.metric_learning

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, params) in enumerate(train_loader):
        B = imgs.size(0)  # original batch size
        hp_bert_embs = params['hardpos_emb'].cuda(non_blocking=True).squeeze(1)

        imgs = imgs.cuda(rank, non_blocking=True)
        word_id = word_id.cuda(rank, non_blocking=True)
        word_mask = word_mask.cuda(rank, non_blocking=True)
        seg_map = seg_map.cuda(rank, non_blocking=True)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        # FIX: bind the filters up front. Previously they were defined only
        # when hp_bert_embs was non-empty, but the metric-learning branch
        # below reads them unconditionally -> NameError on an empty batch.
        rows_to_filter, cols_to_filter = None, None
        if hp_bert_embs.numel() > 0:
            norms = torch.norm(hp_bert_embs, dim=-1, keepdim=True)
            normed_embs = hp_bert_embs / norms
            cosine_sim = torch.mm(normed_embs, normed_embs.T)
            # Near-duplicate phrases must not be treated as negatives.
            rows_to_filter, cols_to_filter = torch.where(cosine_sim > filter_thres)

        with autocast():
            mask_out, metric_tensors = model(image, word_id, word_mask)

            # Training-time IoU on detached numpy copies (monitoring only).
            mask_out_np = mask_out.data.cpu().numpy()  # [bs, 1, 208, 208]
            seg_map_np = seg_map.cpu().numpy()         # [bs, 1, 208, 208]
            seg_iou = cal_seg_iou_loss(seg_map_np, mask_out_np, args.seg_thresh)

            dice_loss_ = dice_loss(mask_out, seg_map)
            sigmoid_focal_loss_ = sigmoid_focal_loss(mask_out, seg_map)

            dice_weight, focal_weight = 1.0, 1.0
            loss = (dice_weight * dice_loss_) + (focal_weight * sigmoid_focal_loss_)

            # Angular contrastive loss over the pooled batch embeddings.
            if metric_learning:
                metric_loss = UniAngularLogitContrastLoss(metric_tensors, rows_to_filter, cols_to_filter, m=args.margin_value, tau=args.temperature, verbonly=True, args=args)
                loss += mlw * metric_loss

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.update(loss.item(), B)
        dice_losses.update(dice_loss_.item(), B)
        sigmoid_focal_losses.update(sigmoid_focal_loss_.item(), B)
        cos_losses.update(seg_iou.mean().item(), B)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and batch_idx % args.print_freq == 0:
            print_str = 'Epoch: [{0}][{1}/{2}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                'dice_losses {dice_losses.val:.4f} ({dice_losses.avg:.4f})\t' \
                'sigmoid_focal_losses {sigmoid_focal_losses.val:.4f} ({sigmoid_focal_losses.avg:.4f})\t' \
                'IoU {cos_loss.val:.4f} ({cos_loss.avg:.4f})\t' \
                .format(epoch, batch_idx, len(train_loader), batch_time=batch_time, loss=losses, dice_losses=dice_losses, sigmoid_focal_losses=sigmoid_focal_losses, cos_loss=cos_losses)
            print(print_str)
            logger.info(print_str)

    return losses.avg
|
| 161 |
+
|
| 162 |
+
def validate_epoch(args, val_loader, model, logger, mode='val'):
    """Evaluate a referring-segmentation model over ``val_loader``.

    For each sample the predicted mask is un-letterboxed (crop out the
    padding, resize back to the ground-truth map's scale) and compared to
    the ground-truth segmentation map.

    Args:
        args: namespace; reads ``args.size`` (network input size) and
            ``args.seg_thresh`` (binarization threshold for IoU).
        val_loader: yields (imgs, word_id, word_mask, bbox, seg_map, ratio,
            dw, dh, im_id, phrase, draw_img) tuples.
            NOTE(review): indexing ``dh[0]``/``dw[0]`` and ``seg_map[0]``
            suggests this loop assumes batch size 1 — confirm against the
            loader configuration.
        model: callable returning ``(mask_out, _)`` for
            ``(image, word_id, word_mask)``.
        logger: logger used for progress and final metrics.
        mode: unused here; kept for caller compatibility.

    Returns:
        tuple: (mean IoU over samples, overall IoU = sum(I)/sum(U),
        dict of threshold -> precision ``AverageMeter``).
    """
    print('begin test')
    batch_time = AverageMeter()
    # NOTE(review): ``miou`` is never updated below; only ``miou_seg`` is used.
    miou = AverageMeter()
    miou_seg = AverageMeter()

    # One precision meter per IoU threshold in [0.5, 0.95].
    prec=dict()
    thresholds = np.arange(0.5, 1, 0.05)

    for thresh in thresholds:
        prec[thresh]= AverageMeter()

    model.eval()
    end = time.time()
    idx = 0

    # Per-batch forward-pass wall times; collected but not reported here.
    t_all = []
    # Accumulators for overall IoU (sum of intersections / sum of unions).
    total_intersection = 0.0
    total_union = 0.0

    for batch_idx, (imgs, word_id, word_mask, bbox, seg_map, ratio, dw, dh, im_id, phrase, draw_img) in enumerate(val_loader):

        imgs = imgs.cuda(0)
        word_id = word_id.cuda(0)
        word_mask = word_mask.cuda(0)
        seg_map = seg_map.cuda(0)
        image = Variable(imgs)
        word_id = Variable(word_id)
        word_mask = Variable(word_mask)
        seg_map = Variable(seg_map)

        # Time only the forward pass.
        t1 = time.time()
        with torch.no_grad():
            mask_out, _ = model(image, word_id, word_mask)
            mask_out = mask_out.sigmoid()

        t2 = time.time()
        t_all.append(t2-t1)

        ## test: convert pred, gt box to original scale with meta-info
        # (ih, iw): ground-truth map size; (nh, nw): content size inside the
        # letterboxed network input; (dh, dw): letterbox padding offsets.
        ih = seg_map.shape[-2]
        iw = seg_map.shape[-1]
        nh = int(ih * ratio)
        nw = int(iw * ratio)
        top, bottom = int(dh[0]), nh + int(dh[0])
        left, right = int(dw[0]), nw + int(dw[0])
        ratio = float(ratio)
        new_shape = (iw, ih)  # cv2.resize takes (width, height)

        ## revert image for visualization
        seg_map_np = seg_map[0,:,:,:].data.cpu().numpy().transpose(1,2,0)
        # NOTE(review): cubic interpolation on a (presumably binary) GT map
        # produces non-binary values; cal_seg_iou2 presumably thresholds —
        # confirm.
        seg_map_np = cv2.resize(seg_map_np, new_shape, interpolation=cv2.INTER_CUBIC)
        img_np = imgs[0,:,top:bottom,left:right].data.cpu().numpy().transpose(1,2,0)
        img_np = cv2.resize(img_np, new_shape, interpolation=cv2.INTER_CUBIC)

        # NOTE(review): ``img_np`` is not used after this point; leftover
        # from a visualization path.
        img_np = Variable(torch.from_numpy(img_np.transpose(2,0,1)).cuda().unsqueeze(0))

        # seg: upsample prediction to the network input size, drop the
        # letterbox padding, then rescale to the GT map's size.
        mask_out = mask_out[0].data.cpu().numpy().transpose(1,2,0)
        mask_out = cv2.resize(mask_out, (args.size, args.size))
        mask_out_np = mask_out[top:bottom, left:right]
        mask_out_np = cv2.resize(mask_out_np, new_shape)
        # seg_iou, seg_prec = cal_seg_iou(seg_map[0].cpu().numpy(), mask_out_np, args.seg_thresh)
        seg_iou, seg_prec, inter_sum, union_sum = cal_seg_iou2(seg_map_np, mask_out_np, args.seg_thresh)

        miou_seg.update(seg_iou, imgs.size(0))
        total_intersection += inter_sum
        total_union += union_sum

        for thresh in thresholds:
            prec[thresh].update(seg_prec[thresh], imgs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 1000 == 0:
            print_str = '[{0}/{1}]\t' \
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                'seg_iu {seg.val:.4f} ({seg.avg:.4f})\t' \
                .format( \
                batch_idx, len(val_loader), batch_time=batch_time, seg=miou_seg)
            print(print_str)
            logger.info(print_str)
        idx = idx + 1
    # Epsilons avoid 0/0 on an empty loader or all-empty masks.
    overall_iou = (total_intersection + 1e-10) / (total_union + 1e-10)

    print("Mean IoU:", miou_seg.avg)
    print("Overall IoU:", overall_iou)
    logger.info("Mean IoU: %.4f" % miou_seg.avg)
    logger.info("Overall IoU: %.4f" % overall_iou)

    for thresh in thresholds:
        print("prec@%f: %f"%(thresh,float(prec[thresh].avg)))
        logger.info("prec@%f:%f"%(thresh,float(prec[thresh].avg)))
    # logger.info("%f,%f"%(float(miou.avg), miou_seg.avg))
    return miou_seg.avg, overall_iou, prec
ASDA/engine/tmp.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import pdb
|
| 8 |
+
import torch.cuda.amp as amp
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import wandb
|
| 12 |
+
from loguru import logger
|
| 13 |
+
from utils.dataset_verbonly import tokenize
|
| 14 |
+
from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather,
|
| 15 |
+
trainMetricGPU)
|
| 16 |
+
|
| 17 |
+
## todo : add oIoU metric
|
| 18 |
+
def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
    """Run one distributed mixed-precision training epoch.

    The forward pass runs under ``amp.autocast`` and the model itself returns
    the loss (``pred, target, loss = model(...)``); the backward pass goes
    through the supplied ``GradScaler``. Loss/IoU/Pr@50 are all-reduced
    across ranks for logging meters only.

    Args:
        train_loader: yields (image, text, target, hardpos, params) where
            ``params['hardpos_emb']`` carries the hard-positive embedding.
        model: DDP-wrapped model returning (pred, target, loss).
        optimizer: optimizer stepped via ``scaler.step``.
        scheduler: only queried for ``get_last_lr`` here.
            NOTE(review): ``scheduler.step()`` is commented out below —
            presumably stepped per-epoch by the caller; confirm.
        scaler: ``torch.cuda.amp.GradScaler``.
        epoch: current epoch index (used in the progress prefix).
        args: reads ``args.epochs``, ``args.max_norm``
            (and ``args.print_freq`` in the disabled logging block).

    Returns:
        None; metrics are tracked in local ``AverageMeter``s
        (per-iteration wandb/console logging is currently disabled).
    """
    # torch.autograd.set_detect_anomaly(True)
    batch_time = AverageMeter('Batch', ':2.2f')
    data_time = AverageMeter('Data', ':2.2f')
    lr = AverageMeter('Lr', ':1.6f')
    loss_meter = AverageMeter('Loss', ':2.4f')
    iou_meter = AverageMeter('IoU', ':2.2f')
    pr_meter = AverageMeter('Prec@50', ':2.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
        prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))

    model.train()
    time.sleep(2)
    end = time.time()

    # size_list = [320, 352, 384, 416, 448, 480, 512]
    # idx = np.random.choice(len(size_list))
    # new_size = size_list[idx]

    for i, (image, text, target, hardpos, params) in enumerate(train_loader):
        data_time.update(time.time() - end)

        # data
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        # unsqueeze(1) adds a channel dim to the target mask.
        target = target.cuda(non_blocking=True).unsqueeze(1)
        hardpos = hardpos.cuda(non_blocking=True)
        hp_emb = params['hardpos_emb'].cuda(non_blocking=True)

        # Forward under autocast; the model computes the loss internally.
        with amp.autocast():
            pred, target, loss = model(image, text, target, hardpos, hp_emb)  # , fq, vis, word, state

        # backward
        optimizer.zero_grad()
        # scaler.scale(loss).backward()
        scaler.scale(loss).backward()
        # loss.backward()

        # NOTE(review): leftover gradient NaN/Inf debugging hooks, kept
        # disabled on purpose.
        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         if torch.isinf(param.grad).any() or torch.isnan(param.grad).any():
        #             print(f"Inf/NaN in gradients: {name}")
        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         grad_norm = param.grad.norm()
        #         if torch.isnan(grad_norm):
        #             print(f"NaN gradient detected in {name}")

        # NOTE(review): clipping scaled gradients without scaler.unscale_
        # first clips in the scaled domain — confirm this is intended.
        if args.max_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        # optimizer.step()
        # scheduler.step()
        scaler.step(optimizer)
        scaler.update()

        # metric: IoU at mask threshold 0.35, precision at IoU>=0.5,
        # averaged across all ranks (logging only — no gradient impact).
        iou, pr5 = trainMetricGPU(pred, target, 0.35, 0.5)
        dist.all_reduce(loss.detach())
        dist.all_reduce(iou)
        dist.all_reduce(pr5)
        loss = loss / dist.get_world_size()
        iou = iou / dist.get_world_size()
        pr5 = pr5 / dist.get_world_size()

        loss_meter.update(loss.item(), image.size(0))
        iou_meter.update(iou.item(), image.size(0))
        pr_meter.update(pr5.item(), image.size(0))
        lr.update(scheduler.get_last_lr()[-1])
        batch_time.update(time.time() - end)
        end = time.time()

        # Per-iteration console/wandb logging, currently disabled.
        # if (i + 1) % args.print_freq == 0:
        #     progress.display(i + 1)
        #     if dist.get_rank() in [-1, 0]:
        #         wandb.log(
        #             {
        #                 "time/batch": batch_time.val,
        #                 "time/data": data_time.val,
        #                 "training/lr": lr.val,
        #                 "training/loss": loss_meter.val,
        #                 "training/iou": iou_meter.val,
        #                 "training/prec@50": pr_meter.val,
        #             },
        #             step=epoch * len(train_loader) + (i + 1))
+
@torch.no_grad()
def validate(val_loader, model, epoch, args):
    """Distributed validation: compute mean IoU, overall IoU and Pr@K.

    Predictions are sigmoid probabilities, upsampled to the input image size
    if needed, binarized at 0.35 and compared per-sample against the loader's
    ground-truth masks. Per-sample IoUs plus raw intersection/union maps are
    gathered across all ranks with ``concat_all_gather`` before reduction, so
    overall IoU (sum I / sum U) reflects the full validation set.

    Args:
        val_loader: yields (imgs, texts, masks, param); ``masks`` are CPU
            tensors aligned with ``imgs`` spatially.
            NOTE(review): ``param`` is unused on this active path (the
            warp-back-to-original-size branch is commented out), so IoU is
            measured at network resolution — confirm that is intended.
        model: callable ``model(imgs, texts) -> logits``.
        epoch: epoch index, for the log line only.
        args: reads ``args.epochs``.

    Returns:
        tuple: (mean IoU as float, overall IoU as float,
        dict 'Pr@50'..'Pr@90' -> float).
    """
    iou_list = []
    # Raw boolean intersection/union maps, kept whole so overall IoU can be
    # computed exactly after the cross-rank gather.
    I_list = []
    U_list = []
    model.eval()
    time.sleep(2)
    for imgs, texts, masks, param in val_loader:
        # data
        imgs = imgs.cuda(non_blocking=True)
        texts = texts.cuda(non_blocking=True)
        # inference
        preds = model(imgs, texts)
        preds = torch.sigmoid(preds)
        if preds.shape[-2:] != imgs.shape[-2:]:
            preds = F.interpolate(preds,
                                  size=imgs.shape[-2:],
                                  mode='bicubic',
                                  align_corners=True).squeeze(1)
        # process one batch
        # Disabled variant: warp predictions back to original image size and
        # read GT masks from disk.
        # for pred, mask_dir, mat, ori_size in zip(preds, param['mask_dir'],
        #                                          param['inverse'],
        #                                          param['ori_size']):
        #     h, w = np.array(ori_size)
        #     mat = np.array(mat)
        #     pred = pred.cpu().numpy()
        #     pred = cv2.warpAffine(pred, mat, (w, h),
        #                           flags=cv2.INTER_CUBIC,
        #                           borderValue=0.)
        #     pred = np.array(pred > 0.35)
        #     mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
        #     mask = mask / 255.
        #     # iou
        #     inter = np.logical_and(pred, mask)
        #     union = np.logical_or(pred, mask)
        #     iou = np.sum(inter) / (np.sum(union) + 1e-6)
        #     iou_list.append(iou)
        #     I_list.append(inter)
        #     U_list.append(union)
        for pred, mask in zip(preds, masks):
            # h, w = np.array(ori_size)
            # mat = np.array(mat)
            pred = pred.cpu().numpy()
            # pred = cv2.warpAffine(pred, mat, (w, h),
            #                       flags=cv2.INTER_CUBIC,
            #                       borderValue=0.)
            # Binarize at a fixed 0.35 probability threshold.
            pred = np.array(pred > 0.35)
            # mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
            # mask = mask / 255.
            mask = mask.numpy()
            # iou
            inter = np.logical_and(pred, mask)
            union = np.logical_or(pred, mask)
            iou = np.sum(inter) / (np.sum(union) + 1e-6)
            I_list.append(inter)
            U_list.append(union)
            iou_list.append(iou)

    # Gather per-sample results from every rank. NOTE(review): np.stack
    # requires every sample's mask to share one shape across the dataset.
    iou_list = np.stack(iou_list)
    iou_list = torch.from_numpy(iou_list).to(imgs.device)
    iou_list = concat_all_gather(iou_list)

    I_list = np.stack(I_list)
    I_list = torch.from_numpy(I_list).to(imgs.device)
    I_list = concat_all_gather(I_list)

    U_list = np.stack(U_list)
    U_list = torch.from_numpy(U_list).to(imgs.device)
    U_list = concat_all_gather(U_list)

    overall_I = I_list.sum().item()
    overall_U = U_list.sum().item()
    overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero

    # Precision at IoU thresholds 0.5, 0.6, ..., 0.9.
    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (iou_list > thres).float().mean()
        prec_list.append(tmp)
    iou = iou_list.mean()
    prec = {}
    temp = ' '
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres * 10)
        value = prec_list[i].item()
        prec[key] = value
        temp += "{}: {:.2f} ".format(key, 100. * value)
    head = 'Evaluation: Epoch=[{}/{}] IoU={:.2f} OIoU={:.4f}'.format(
        epoch, args.epochs, 100. * iou.item(), 100. * overall_IoU)
    logger.info(head + temp)
    # print(head)

    # return three results : mIoU, oIoU and prec results
    return iou.item(), overall_IoU, prec
| 205 |
+
@torch.no_grad()
def inference(test_loader, model, args):
    """Run test-time inference with per-sentence evaluation.

    Each test image carries multiple referring sentences
    (``param['sents']``); every sentence is tokenized, run through the
    model, binarized at 0.35 and scored against the shared ground-truth
    mask. Unlike :func:`validate`, this path is single-process: no
    cross-rank gather is performed.

    Args:
        test_loader: yields (img, mask, param) with batch size 1
            (``mask[0]``, ``param['sents']`` indexing assumes this).
        model: callable ``model(img, text) -> logits``.
        args: reads ``args.word_len``, ``args.visualize``, ``args.vis_dir``.

    Returns:
        tuple: (mean IoU as float, overall IoU as float,
        dict 'Pr@50'..'Pr@90' -> float).
    """
    iou_list = []
    # Raw boolean intersection/union maps for the overall-IoU reduction.
    I_list = []
    U_list = []

    tbar = tqdm(test_loader, desc='Inference:', ncols=100)
    model.eval()
    time.sleep(2)
    for img, mask, param in tbar:
        # data
        # img = img.cuda(non_blocking=True)
        # mask = cv2.imread(param['mask_dir'][0], flags=cv2.IMREAD_GRAYSCALE)
        img = img.cuda(non_blocking=True)
        mask = mask[0].cpu().numpy()

        # dump image & mask
        if args.visualize:
            seg_id = param['seg_id'][0].cpu().numpy()
            img_name = '{}-img.jpg'.format(seg_id)
            mask_name = '{}-mask.png'.format(seg_id)
            cv2.imwrite(filename=os.path.join(args.vis_dir, img_name),
                        img=param['ori_img'][0].cpu().numpy())
            cv2.imwrite(filename=os.path.join(args.vis_dir, mask_name),
                        img=mask)
        # multiple sentences
        for sent in param['sents']:
            # mask = mask / 255.
            text = tokenize(sent, args.word_len, True)
            text = text.cuda(non_blocking=True)
            # inference
            pred = model(img, text)
            pred = torch.sigmoid(pred)
            if pred.shape[-2:] != img.shape[-2:]:
                pred = F.interpolate(pred,
                                     size=img.shape[-2:],
                                     mode='bicubic',
                                     align_corners=True).squeeze()
            # process one sentence
            # NOTE(review): warp back to the original image size is
            # disabled, so IoU is measured at network resolution — confirm.
            # h, w = param['ori_size'].numpy()[0]
            # mat = param['inverse'].numpy()[0]
            pred = pred.cpu().numpy()
            # pred = cv2.warpAffine(pred, mat, (w, h),
            #                       flags=cv2.INTER_CUBIC,
            #                       borderValue=0.)
            # Binarize at a fixed 0.35 probability threshold.
            pred = np.array(pred > 0.35)
            # iou
            inter = np.logical_and(pred, mask)
            union = np.logical_or(pred, mask)
            iou = np.sum(inter) / (np.sum(union) + 1e-6)
            iou_list.append(iou)
            I_list.append(inter)
            U_list.append(union)
            # dump prediction
            if args.visualize:
                pred = np.array(pred*255, dtype=np.uint8)
                sent = "_".join(sent[0].split(" "))
                pred_name = '{}-iou={:.2f}-{}.png'.format(seg_id, iou*100, sent)
                cv2.imwrite(filename=os.path.join(args.vis_dir, pred_name),
                            img=pred)
    logger.info('=> Metric Calculation <=')
    # NOTE(review): np.stack requires all per-sentence masks to share one
    # shape across the dataset.
    iou_list = np.stack(iou_list)
    iou_list = torch.from_numpy(iou_list).to(img.device)

    I_list = np.stack(I_list)
    I_list = torch.from_numpy(I_list).to(img.device)
    U_list = np.stack(U_list)
    U_list = torch.from_numpy(U_list).to(img.device)
    overall_I = I_list.sum().item()
    overall_U = U_list.sum().item()
    overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero

    # Precision at IoU thresholds 0.5, 0.6, ..., 0.9.
    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (iou_list > thres).float().mean()
        prec_list.append(tmp)
    iou = iou_list.mean()
    prec = {}
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres*10)
        value = prec_list[i].item()
        prec[key] = value
    logger.info('IoU={:.2f} OIoU={:.4f}'.format(100.*iou.item(), 100. * overall_IoU))
    print('IoU={:.2f} OIoU={:.4f}'.format(100.*iou.item(), 100. * overall_IoU))
    for k, v in prec.items():
        logger.info('{}: {:.2f}.'.format(k, 100.*v))

    return iou.item(), overall_IoU, prec