|
|
import torch.nn.functional as F
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from utils import segutils
|
|
|
import core.denseaffinity as dautils
|
|
|
|
|
|
def identity_mapping(x, *args, **kwargs):
    """Return ``x`` unchanged, swallowing any extra positional/keyword args.

    Used as a drop-in stand-in for ``F.normalize`` (which is called as
    ``f(x, p=2, dim=1)``) when normalization is disabled, so callers can pass
    the same arguments either way.  A ``def`` replaces the original assigned
    lambda (PEP 8: don't bind lambdas to names).
    """
    return x
|
|
|
|
|
|
|
|
|
class ContrastiveConfig:
    """Attribute-style view over a nested settings dict.

    Settings are read like ``cfg.fitting.nce.temperature``.  Nested dicts are
    wrapped in fresh ContrastiveConfig instances on every access; the wrapper
    shares the underlying dict by reference, so writing through a nested view
    (``cfg.aug.debug = True``) mutates the root settings tree as well.
    """

    def __init__(self, config=None):
        # Build the default settings tree when no dict is given; otherwise
        # wrap the supplied dict by reference (no copy is made).
        if config is None:
            self._data = {
                # Augmentation sampling parameters (consumed by Augmen).
                'aug': {
                    'n_transformed_imgs': 2,  # augmented copies per image
                    'blurkernelsize': [1],  # candidate Gaussian-blur kernel sizes
                    'maxjitter': 0.0,  # max color-jitter strength
                    'maxangle': 0,  # max |rotation| (degrees)
                    'maxscale': 1.0,  # scale sampled from [maxscale, 1]
                    'maxshear': 20,  # max |shear| (degrees)
                    'randomhflip': False,  # NOTE(review): not read anywhere in this file
                    'apply_affine': True,  # also warp images/masks geometrically
                    'debug': False
                },
                # ContrastiveFeatureTransformer head architecture.
                'model': {
                    'out_channels': 64,
                    'kernel_size': 1,
                    'prepend_relu': False,  # ReLU the backbone features first
                    'append_normalize': False,  # L2-normalize head output
                    'debug': False
                },
                # Per-episode fitting of the contrastive head.
                'fitting': {
                    'lr': 1e-2,
                    'optimizer': torch.optim.SGD,  # optimizer class, not instance
                    'num_epochs': 25,
                    'nce': {
                        'temperature': 0.5,
                        'debug': False
                    },
                    'normalize_after_fwd_pass': True,
                    'q_nceloss': True,  # dense NCE on query features
                    's_nceloss': True,  # dense NCE on support features
                    'protoloss': False,  # fg/bg prototype separation loss
                    'keepvarloss': True,  # keep feature mean/variance stable
                    'symmetricloss': False,  # penalize |qloss - sloss|
                    'selfattentionloss': False,
                    'o_t_contr_proto_loss': True,  # contrastive prototype loss
                    'debug': False
                },
                # Feature-extraction bookkeeping.
                'featext': {
                    'l0': 3,  # first backbone layer index that gets a head
                    'fit_every_episode': False  # refit heads on every episode
                }
            }
        else:
            self._data = config

    def __getattr__(self, key):
        # Only invoked for names not found through normal lookup.  Resolve the
        # name in the backing dict; nested dicts are re-wrapped so chained
        # attribute access keeps working.
        _data = super().__getattribute__('_data') if '_data' in self.__dict__ else None
        if _data is not None and key in _data:
            if isinstance(_data[key], dict):
                return ContrastiveConfig(_data[key])
            return _data[key]
        raise AttributeError(f"No setting named {key}")

    def __setattr__(self, key, value):
        # '_data' itself is stored as a real attribute; everything else is
        # routed into the backing dict (mutating shared nested dicts in place).
        if key == '_data':
            super().__setattr__(key, value)
        else:
            _data = super().__getattribute__('_data') if '_data' in self.__dict__ else None
            if _data is not None:
                _data[key] = value
            else:
                # Can only happen if assignment precedes _data initialization.
                raise AttributeError("Unexpected")

    def __repr__(self):
        return str(self._data)
|
|
|
|
|
|
|
|
|
def dense_info_nce_loss(original_features, transformed_features, config_nce):
    """Dense (per-pixel) InfoNCE loss between two feature volumes.

    Each spatial position of ``original_features`` forms a positive pair with
    the same position in ``transformed_features``; every other position in the
    same image acts as a negative.

    Args:
        original_features: (1, C, H, W) or (B, C, H, W) tensor, expanded over
            the batch of ``transformed_features``.
        transformed_features: (B, C, H, W) tensor.
        config_nce: settings object with ``temperature`` and ``debug``.

    Returns:
        Scalar mean InfoNCE loss.
    """
    B, C, H, W = transformed_features.shape
    # Flatten both volumes to (B, H*W, C) rows of per-pixel feature vectors.
    o_features = original_features.expand(B, C, H, W).permute(0, 2, 3, 1).reshape(B, H * W, C)
    t_features = transformed_features.permute(0, 2, 3, 1).reshape(B, H * W, C)

    # Positive logits: same pixel, original vs transformed -> (B, H*W).
    positive_logits = torch.einsum('bik,bik->bi', o_features, t_features) / config_nce.temperature

    # All logits: every original pixel vs every transformed pixel -> (B, H*W, H*W).
    all_logits = torch.einsum('bik,bjk->bij', o_features, t_features) / config_nce.temperature

    if config_nce.debug: print('pos/neg:', positive_logits.mean().detach(), all_logits.mean().detach())

    # torch.logsumexp is numerically stable and yields (B, H*W) directly.
    # (The previous hand-rolled version used a bare .squeeze(), which collapsed
    # the wrong dimension and broadcast incorrectly whenever H*W == 1.)
    log_sum_exp = torch.logsumexp(all_logits, dim=-1)

    loss = -(positive_logits - log_sum_exp)
    return loss.mean()
|
|
|
|
|
|
|
|
|
def ssim(a, b):
    """Cosine similarity between ``a`` and ``b`` along dim 1 (channels)."""
    # Functional form of nn.CosineSimilarity() with its defaults (dim=1).
    return F.cosine_similarity(a, b, dim=1)
|
|
|
|
|
|
def augwise_proto(feat_vol, mask, k, aug):
    """Compute one foreground and one background prototype per augmentation.

    The k shots belonging to the same augmentation are concatenated along the
    spatial axis before masked averaging, so each of the ``aug`` prototypes
    pools over all shots.  Returns ``(fg_proto, bg_proto)``, each (aug, c).
    """
    c, h, w = feat_vol.shape[-3:]
    # Regroup (k*aug, c, h, w) so all shots of one augmentation sit side by side.
    shot_feats = feat_vol.view(k, aug, c, h * w).unbind(0)
    feats_by_aug = torch.cat(shot_feats, dim=-1)
    shot_masks = segutils.downsample_mask(mask, h, w).view(k, aug, h * w).unbind(0)
    masks_by_aug = torch.cat(shot_masks, dim=-1)
    assert feats_by_aug.shape == (aug, c, k * h * w) and masks_by_aug.shape == (aug, k * h * w), "of transformed"

    fg_proto, bg_proto = segutils.fg_bg_proto(feats_by_aug, masks_by_aug)
    assert fg_proto.shape == bg_proto.shape == (aug, c)
    return fg_proto, bg_proto
|
|
|
|
|
|
|
|
|
def calc_q_pred_coarse_nodetach(qft, sft, s_mask, l0=3):
    """Predict a coarse query mask from support features via dense affinity.

    Gradients are kept (no detach), so this can serve as a training signal.
    NOTE(review): ``l0`` is accepted for interface compatibility but unused here.

    Args:
        qft: query features (bsz, c, hq, wq).
        sft: support features with a shot dimension at axis 1.
        s_mask: support masks with a shot dimension at axis 1.

    Returns:
        (bsz, hq, wq) coarse prediction.
    """
    bsz, c, hq, wq = qft.shape
    hs, ws = sft.shape[-2:]

    # Lay every support shot side by side into one wide feature map / mask row.
    support_row = torch.cat(sft.unbind(1), -1)
    mask_row = torch.cat([segutils.downsample_mask(m, hs, ws) for m in s_mask.unbind(1)], -1)

    affinity = dautils.buildDenseAffinityMat(qft, support_row)
    prediction = dautils.filterDenseAffinityMap(affinity, mask_row)
    return prediction.view(bsz, hq, wq)
|
|
|
|
|
|
|
|
|
|
|
|
def self_attention_loss(f_base, f_transformed, mask_base, mask_transformed, k, aug):
    """BCE between an attention-based mask prediction and the base mask.

    The non-transformed features act as a pseudo query (one wide map per
    augmentation), the transformed features/masks as the pseudo support set;
    the coarse prediction should then reproduce the base mask.
    """
    c, h, w = f_base.shape[-3:]
    # Pseudo query: shots of each augmentation concatenated along width.
    pseudo_q = torch.cat(f_base.view(k, aug, c, h, w).unbind(0), -1)
    pseudo_q_mask = torch.cat(mask_base.view(k, aug, h, w).unbind(0), -1)
    # Pseudo support: transformed features with (aug, k, ...) layout.
    pseudo_s = f_transformed.view(k, aug, c, h, w).transpose(0, 1)
    pseudo_s_mask = mask_transformed.view(k, aug, h, w).transpose(0, 1)

    prediction = calc_q_pred_coarse_nodetach(pseudo_q, pseudo_s, pseudo_s_mask, l0=0)

    bce = torch.nn.BCELoss()(prediction.float(), pseudo_q_mask.float())
    return bce.mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ctrstive_prototype_loss(base, transformed, mask_base, mask_transformed, k, aug):
    """Contrastive prototype loss between original and transformed features.

    Per augmentation, the original foreground prototype should be similar to
    the transformed foreground prototype (positive) and dissimilar to the
    transformed background prototype (negative): a 2-way InfoNCE over cosine
    similarities.

    Args:
        base, transformed: (k*aug, c, h, w) feature volumes.
        mask_base, mask_transformed: (k*aug, h, w) masks.
        k, aug: shot and augmentation counts with k*aug == dim0.

    Returns:
        Scalar mean loss over augmentations.
    """
    assert transformed.shape == base.shape, ".."
    b, c, h, w = base.shape
    assert b == k * aug, 'provide correct k and aug such that dim0=k*aug'
    assert mask_base.shape == mask_transformed.shape == (b, h, w), ".."
    fg_proto_o, _ = augwise_proto(base, mask_base, k, aug)
    fg_proto_t, bg_proto_t = augwise_proto(transformed, mask_transformed, k, aug)

    # Positive: original-fg vs transformed-fg (computed once; was evaluated
    # twice in the numerator and the denominator).
    pos = torch.exp(ssim(fg_proto_o, fg_proto_t))
    # Negative: original-fg vs transformed-bg.
    neg = torch.exp(ssim(fg_proto_o, bg_proto_t))

    enumer = pos
    denom = pos + neg
    assert enumer.shape == denom.shape == torch.Size([aug]), 'you want to calculate one prototype for each augmentation'
    loss = -torch.log(enumer / denom)
    return loss.mean()
|
|
|
|
|
|
|
|
|
def opposite_proto_sim_in_aug(transformed_features, mapped_s_masks, k, aug):
    """Mean cosine similarity between fg and bg prototypes of the transformed
    features — lower means better class separation within each augmentation."""
    fg, bg = augwise_proto(transformed_features, mapped_s_masks, k, aug)
    return ssim(fg, bg).mean()
|
|
|
|
|
|
|
|
|
def proto_align_val_measure(original_features, transformed_features, mapped_s_masks, k, aug):
    """Validation measure: mean similarity of the foreground prototypes of
    original vs transformed features (higher = better alignment)."""
    fg_orig, _ = augwise_proto(original_features, mapped_s_masks, k, aug)
    fg_trans, _ = augwise_proto(transformed_features, mapped_s_masks, k, aug)
    return ssim(fg_orig, fg_trans).mean()
|
|
|
|
|
|
|
|
|
def atest():
    """Smoke test: run self_attention_loss on random fixture tensors."""
    k, aug, c, h, w = 2, 5, 8, 20, 20
    # Same RNG call order as before: rand, rand, randint, randint.
    base_feats = torch.rand(k * aug, c, h, w).float()
    base_feats.requires_grad = True
    trans_feats = torch.rand(k * aug, c, h, w).float()
    base_masks = torch.randint(0, 2, (k * aug, h, w)).float()
    trans_masks = torch.randint(0, 2, (k * aug, h, w)).float()

    return self_attention_loss(base_feats, trans_feats, base_masks, trans_masks, k, aug)
|
|
|
|
|
|
def keep_var_loss(original_features, transformed_features):
    """Penalize drift of per-channel spatial statistics between the original
    and transformed feature volumes (L1 on mean and variance differences)."""
    mean_shift = original_features.mean((-2, -1)) - transformed_features.mean((-2, -1))
    var_shift = original_features.var((-2, -1)) - transformed_features.var((-2, -1))
    return mean_shift.abs().mean() + var_shift.abs().mean()
|
|
|
|
|
|
class ContrastiveFeatureTransformer(nn.Module):
    """Small convolutional head (conv -> BN -> ReLU -> 1x1 conv) that is fitted
    per episode with contrastive losses so backbone features become
    task-adapted while staying spatially aligned.
    """

    def __init__(self, in_channels, config_model):
        """Build the head.

        Args:
            in_channels: channel count of the backbone features fed in.
            config_model: settings with ``out_channels``, ``kernel_size``,
                ``prepend_relu``, ``append_normalize`` and ``debug``.
        """
        super(ContrastiveFeatureTransformer, self).__init__()

        out_channels, kernel_size = config_model.out_channels, config_model.kernel_size

        # 'same' padding for odd kernel sizes.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2)
        self.bn = nn.BatchNorm2d(out_channels)
        # 1x1 conv acting as a per-pixel linear projection.
        self.linear = nn.Conv2d(out_channels, out_channels, 1)

        self.prepend_relu = config_model.prepend_relu
        self.append_normalize = config_model.append_normalize
        self.debug = config_model.debug

    def forward(self, x):
        """Map (B, in_channels, H, W) -> (B, out_channels, H, W); optionally
        ReLU the input first and/or L2-normalize the output channels."""
        if self.prepend_relu:
            x = nn.ReLU()(x)
        x = self.conv(x)
        x = self.bn(x)
        x = nn.ReLU()(x)
        x = self.linear(x)
        if self.append_normalize:
            x = F.normalize(x, p=2, dim=1)
        return x

    def fit(self, mapped_qfeat_vol, aug_qfeat_vols, mapped_sfeat_vol, aug_sfeat_vols, augmented_smasks, config_fit):
        """Fit the head on one episode with a sum of contrastive losses.

        Args:
            mapped_qfeat_vol: query features with the affine warps applied,
                flattened to (aug, c, h, w).
            aug_qfeat_vols: backbone features of the augmented query images.
            mapped_sfeat_vol: support features with affine warps applied,
                flattened to (k*aug, c, h, w).
            aug_sfeat_vols: backbone features of the augmented support images.
            augmented_smasks: augmented support masks, downsampled to (h, w).
            config_fit: fitting settings (losses to enable, optimizer, lr, ...).

        NOTE(review): loss terms default to int 0 when disabled; the debug-only
        ``gradient_magnitude`` calls would fail on such int terms (no
        ``.backward``) — debug mode presumably expects those losses enabled.
        """
        f_norm = F.normalize if config_fit.normalize_after_fwd_pass else identity_mapping
        optimizer = config_fit.optimizer(self.parameters(), lr=config_fit.lr)
        for epoch in range(config_fit.num_epochs):

            # --- query branch: dense NCE (+ optional mean/var preservation) ---
            original_features = f_norm(self(mapped_qfeat_vol), p=2, dim=1)
            transformed_features = f_norm(self(aug_qfeat_vols), p=2, dim=1)

            qloss = dense_info_nce_loss(original_features, transformed_features,
                                        config_fit.nce) if config_fit.q_nceloss else 0
            if config_fit.keepvarloss:
                qloss += keep_var_loss(original_features, transformed_features)

            # --- support branch: same losses; these locals are intentionally
            # reused and feed the prototype losses below ---
            original_features = f_norm(self(mapped_sfeat_vol), p=2, dim=1)
            transformed_features = f_norm(self(aug_sfeat_vols), p=2, dim=1)

            sloss = dense_info_nce_loss(original_features, transformed_features,
                                        config_fit.nce) if config_fit.s_nceloss else 0
            if config_fit.keepvarloss:
                sloss += keep_var_loss(original_features, transformed_features)

            # Recover shot/augmentation counts from the flattened batch.
            kaug, c, h, w = transformed_features.shape
            # NOTE(review): assumes aug_qfeat_vols dim 0 equals the number of
            # augmentations (true for the bsz==1 call in this file) — confirm.
            aug = aug_qfeat_vols.shape[0]
            k = kaug // aug
            # --- mutually exclusive prototype-based losses ---
            if config_fit.protoloss:
                assert not config_fit.o_t_contr_proto_loss, 'only one of the proto losses should be used'
                opposite_proto_sim = opposite_proto_sim_in_aug(transformed_features, augmented_smasks, k, aug)
                if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0): print(
                    'proto-sim intER-class transf<->transf', opposite_proto_sim.item())
                proto_loss = opposite_proto_sim
            elif config_fit.selfattentionloss:
                proto_loss = self_attention_loss(original_features, transformed_features, augmented_smasks,
                                                 augmented_smasks, k, aug)
                if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0): print(
                    'self-att non-transf<->transformed bce', proto_loss.item())
            elif config_fit.o_t_contr_proto_loss:
                o_t_contr_proto_loss = ctrstive_prototype_loss(original_features, transformed_features,
                                                               augmented_smasks, augmented_smasks, k, aug)
                if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0): print(
                    'proto-contr non-transf<->transformed', o_t_contr_proto_loss.item())
                proto_loss = o_t_contr_proto_loss
            else:
                proto_loss = 0

            # Validation-only measurement, printed on first and last epoch.
            if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0):
                proto_align_val = proto_align_val_measure(original_features, transformed_features, augmented_smasks, k,
                                                          aug)
                print('proto-sim intRA-class non-transf<->transformed (for validation)', proto_align_val.item())

            # Optional symmetry term keeping query and support losses close.
            q_s_loss_diff = torch.abs(qloss - sloss) if config_fit.symmetricloss else 0

            loss = qloss + sloss + q_s_loss_diff + proto_loss
            assert loss.isfinite().all(), f"invalid contrastive loss:{loss}"

            if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0):
                # Rough per-term gradient magnitude on the two conv weights.
                def gradient_magnitude(loss_term):
                    optimizer.zero_grad()
                    loss_term.backward(retain_graph=True)
                    magn = torch.abs(self.conv.weight.grad.mean()) + torch.abs(self.linear.weight.grad.mean())
                    return magn

                q_loss_grad_magnitude = gradient_magnitude(qloss)
                s_loss_grad_magnitude = gradient_magnitude(sloss)
                proto_loss_grad_magnitude = gradient_magnitude(proto_loss)
                q_s_loss_diff_grad_magnitude = gradient_magnitude(q_s_loss_diff)
                # NOTE(review): `display` is an IPython/notebook builtin — this
                # debug path only works inside a notebook session.
                display(segutils.tensor_table(q_loss_grad_magnitude=q_loss_grad_magnitude,
                                              s_loss_grad_magnitude=s_loss_grad_magnitude,
                                              proto_loss_grad_magnitude=proto_loss_grad_magnitude,
                                              q_s_loss_diff_grad_magnitude=q_s_loss_diff_grad_magnitude))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if config_fit.debug and epoch % 10 == 0: print('loss', loss.detach())
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
import torch.nn.functional as F
|
|
|
from torchvision.transforms.functional import affine
|
|
|
from torchvision.transforms import GaussianBlur, ColorJitter
|
|
|
|
|
|
|
|
|
class AffineProxy:
    """Freezes one set of sampled affine parameters so the exact same warp can
    be re-applied to images, masks and feature volumes alike."""

    def __init__(self, angle, translate, scale, shear):
        # Stored as a dict so the parameters can be inspected/serialized easily.
        self.affine_params = dict(angle=angle, translate=translate, scale=scale, shear=shear)

    def apply(self, img):
        """Warp ``img`` with the stored affine parameters."""
        params = self.affine_params
        return affine(img, angle=params['angle'], translate=params['translate'],
                      scale=params['scale'], shear=params['shear'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Augmen:
    """Samples a fixed list of photometric + geometric augmentations once and
    re-applies them consistently to images, masks and feature volumes."""

    def __init__(self, config_aug):
        self.config = config_aug
        self.blurs, self.jitters, self.affines = self.setup_augmentations()

    def copy_construct(self, blurs, jitters, affines, config_aug):
        """Adopt previously sampled transforms instead of drawing new ones."""
        self.config = config_aug
        self.blurs, self.jitters, self.affines = blurs, jitters, affines

    def setup_augmentations(self):
        """Draw one (blur, jitter, affine) triple per transformed image.

        RNG call order is: blur kernel, 3x jitter strengths, angle, shear,
        scale — once per augmentation.
        """
        cfg = self.config
        no_translation = (0, 0)

        blur_list, jitter_list, affine_list = [], [], []
        for _ in range(cfg.n_transformed_imgs):
            # Gaussian blur with a randomly chosen kernel size.
            ksize = np.random.choice(torch.tensor(cfg.blurkernelsize), (1,)).item()
            blur_list.append(GaussianBlur(ksize))

            # Color jitter with strengths sampled up to maxjitter.
            brightness = torch.rand(1).item() * cfg.maxjitter
            contrast = torch.rand(1).item() * cfg.maxjitter
            saturation = torch.rand(1).item() * cfg.maxjitter
            jitter_list.append(ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation))

            # Random affine: rotation, shear, scale in [maxscale, 1]; no translation.
            rot = torch.randint(-cfg.maxangle, cfg.maxangle + 1, (1,)).item()
            shear_xy = [torch.randint(-cfg.maxshear, cfg.maxshear + 1, (1,)).item() for _ in range(2)]
            zoom = torch.rand(1).item() * (1 - cfg.maxscale) + cfg.maxscale
            affine_list.append(AffineProxy(angle=rot, translate=no_translation, scale=zoom, shear=shear_xy))

        return (blur_list, jitter_list, affine_list)

    def augment(self, original_image, orignal_mask):
        """Apply every sampled transform to the image; the mask only receives
        the geometric (affine) part.  Outputs are stacked along dim 1."""
        imgs_out, masks_out = [], []
        for blur, jitter, affine_trans in zip(self.blurs, self.jitters, self.affines):
            # Photometric part (image only): blur then jitter.
            t_img = jitter(blur(original_image))
            t_mask = orignal_mask.clone()

            if self.config.apply_affine:
                t_img = affine_trans.apply(t_img)
                t_mask = affine_trans.apply(t_mask)

            imgs_out.append(t_img)
            masks_out.append(t_mask)
        return torch.stack(imgs_out, dim=1), torch.stack(masks_out, dim=1)

    def applyAffines(self, feat_vol):
        """Warp a feature volume with every sampled affine; stacks along dim 1."""
        mapped = [trans.apply(feat_vol) for trans in self.affines]
        return torch.stack(mapped, dim=1)
|
|
|
|
|
|
|
|
|
class CTrBuilder:
    """Builds, fits and applies one ContrastiveFeatureTransformer per backbone
    layer (from layer ``featext.l0`` onward) for a single episode."""

    def __init__(self, config, augmentator=None):
        # Share an augmentator if given; otherwise sample fresh transforms.
        if augmentator is None:
            augmentator = Augmen(config.aug)
        self.augmentator = augmentator

        self.augimgs = self.AugImgStack(augmentator)

        # Set to True once build_and_fit() has run.
        self.hasfit = False
        self.config = config

    class AugImgStack():
        """Container for the augmented query/support images and masks."""

        def __init__(self, augmentator):
            self.augmentator = augmentator
            # q: (bsz, aug, c, h, w); s: (bsz, k, aug, c, h, w);
            # s_mask: (bsz, k, aug, h, w) — allocated by init().
            self.q, self.s, self.s_mask = None, None, None

        def init(self, s_img):
            # Allocate empty buffers on the same device as the support images.
            bsz, k, aug, c, h, w = *s_img.shape[:2], self.augmentator.config.n_transformed_imgs, *s_img.shape[-3:]
            self.q = torch.empty(bsz, aug, c, h, w).to(s_img.device)
            self.s = torch.empty(bsz, k, aug, c, h, w).to(s_img.device)
            self.s_mask = torch.empty(bsz, k, aug, h, w).to(s_img.device)

        def show(self):
            # NOTE(review): `display` is an IPython/notebook builtin — this
            # visualization only works inside a notebook session.
            bsz_, k_, aug_ = self.s.shape[:3]
            for b in range(bsz_):
                display('aug x queries', segutils.pilImageRow(*[segutils.norm(img) for img in self.q[b]]))
                for k in range(k_):
                    print('k=', k, ' aug x (s, smask):')
                    display(segutils.pilImageRow(*[segutils.norm(img) for img in self.s[b, k]]))
                    display(segutils.pilImageRow(*self.s_mask[b, k]))

    def showAugmented(self):
        # Convenience pass-through to the image stack's visualization.
        self.augimgs.show()

    def makeAugmented(self, q_img, s_img, s_mask):
        """Fill the AugImgStack with augmented query/support images and masks."""
        self.augimgs.init(s_img)
        # Query: only the augmented images are kept; the mask output (computed
        # from s_mask here) is discarded.
        self.augimgs.q, _ = self.augmentator.augment(q_img, s_mask)

        # Support: augment every shot with the same sampled transforms.
        for k in range(s_img.shape[1]):
            s_aug_imgs, s_aug_masks = self.augmentator.augment(s_img[:, k], s_mask[:, k])
            self.augimgs.s[:, k] = s_aug_imgs
            self.augimgs.s_mask[:, k] = s_aug_masks
        if self.config.aug.debug: self.augimgs.show()

    def build_and_fit(self, q_feat, s_feat, q_feataug, s_feataug, s_maskaug=None):
        """Build per-layer contrastive heads and fit them on this episode."""
        # Default to the masks produced by makeAugmented().
        if s_maskaug is None: s_maskaug = self.augimgs.s_mask
        self.ctrs = self.buildContrastiveTransformers(q_feat, s_feat, q_feataug, s_feataug, s_maskaug)
        self.hasfit = True

    def buildContrastiveTransformers(self, qfeat_alllayers, sfeat_alllayers, query_feats_aug, support_feats_aug,
                                     supp_aug_mask, s_mask=None):
        """Create and fit one head per backbone layer >= featext.l0.

        Args:
            qfeat_alllayers / sfeat_alllayers: per-layer feature lists.
            query_feats_aug / support_feats_aug: per-layer features of the
                augmented images; support entries are (bsz, k, aug, ch, h, w).
            supp_aug_mask: (bsz, k, aug, H, W) augmented support masks.
            s_mask: optional; if given, triggers notebook visualization.
        """
        contrastive_transformers = []
        l0 = self.config.featext.l0

        # Flatten mask to (bsz*k*aug, H, W).
        s_aug_mask = supp_aug_mask.view(-1, *supp_aug_mask.shape[-2:])

        for (qfeat, sfeat, qfeataug, sfeataug) in zip(qfeat_alllayers[l0:], sfeat_alllayers[l0:], query_feats_aug[l0:],
                                                      support_feats_aug[l0:]):
            bsz, k, aug, ch, h, w = sfeataug.shape

            assert bsz == 1, "bsz should be 1"
            assert supp_aug_mask.shape[1] == sfeat.shape[
                1] == k, f'augmented support shot-dimension mismatch:{s_aug_mask.shape[1]=},{sfeat.shape[1]=},(bsz,k,aug,ch,h,w)={bsz, k, aug, ch, h, w}'
            assert supp_aug_mask.shape[2] == qfeataug.shape[1] == aug, 'augmented shot-dimension mismatch'

            # Collapse leading batch/shot/aug dims to a flat batch.
            qfeat = qfeat.view(-1, *qfeat.shape[-3:])
            sfeat = sfeat.view(-1, *sfeat.shape[-3:])
            qfeataug = qfeataug.view(-1, *qfeataug.shape[-3:])
            # NOTE(review): reshapes sfeataug with qfeataug's trailing shape;
            # correct only while both share (ch, h, w) per layer — confirm.
            sfeataug = sfeataug.view(-1, *qfeataug.shape[-3:])

            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            contrastive_head = ContrastiveFeatureTransformer(in_channels=ch, config_model=self.config.model).to(device)

            # Warp the non-augmented features with the same affines so they
            # align spatially with the augmented-image features.
            mapped_qfeat = self.augmentator.applyAffines(qfeat)
            assert mapped_qfeat.shape[1] == aug, "should be 1,aug,c,h,w"
            mapped_qfeat = mapped_qfeat.view(-1, *qfeat.shape[-3:])
            mapped_sfeat = self.augmentator.applyAffines(sfeat)
            assert mapped_sfeat.shape[1] == aug and mapped_sfeat.shape[0] == k, "should be k,aug,c,h,w"
            mapped_sfeat = mapped_sfeat.view(-1, *sfeat.shape[-3:])

            contrastive_head.fit(mapped_qfeat, qfeataug, mapped_sfeat, sfeataug,
                                 segutils.downsample_mask(s_aug_mask, h, w), self.config.fitting)

            contrastive_transformers.append(contrastive_head)

            # NOTE(review): `!= None` works like `is not None` for tensors here
            # (falls back to identity), but `is not None` is the idiom.
            if s_mask != None:
                # Notebook-only visualization of the adapted affinity maps.
                display(segutils.to_pil(segutils.norm(dautils.filterDenseAffinityMap(
                    dautils.buildDenseAffinityMat(contrastive_head(sfeat), contrastive_head(sfeataug[:1])),
                    segutils.downsample_mask(s_mask, h, w)).view(1, h, w))))
                display(segutils.to_pil(segutils.norm(dautils.filterDenseAffinityMap(
                    dautils.buildDenseAffinityMat(contrastive_head(qfeat), contrastive_head(sfeat)),
                    segutils.downsample_mask(s_mask, h, w)).view(1, h, w))))
        return contrastive_transformers

    def getTaskAdaptedFeats(self, layerwise_feats):
        """Run each layer's features through its fitted head.

        Layers below featext.l0 have no head and map to None.
        NOTE(review): if build_and_fit() was never called, `self.ctrs` does not
        exist and this raises AttributeError before the guard below fires;
        the `== None` print is effectively dead code.
        """
        if (self.ctrs == None): print("error: call buildContrastiveTransformers() first")
        task_adapted_feats = []

        for idx in range(len(layerwise_feats)):
            if idx < self.config.featext.l0:
                task_adapted_feats.append(None)
            else:
                # Flatten leading dims, forward through the head, restore them.
                input_shape = layerwise_feats[idx].shape
                idxth_feat = layerwise_feats[idx].view(-1, *input_shape[-3:])
                forward_pass_res = self.ctrs[idx - self.config.featext.l0](idxth_feat)
                target_shape = *input_shape[:-3], *forward_pass_res.shape[-3:]
                task_adapted_feats.append(forward_pass_res.view(target_shape))

        return task_adapted_feats
|
|
|
|
|
|
|
|
|
class FeatureMaker:
    """Extracts backbone features and task-adapts them with a per-class
    CTrBuilder (one contrastive head builder per class id)."""

    def __init__(self, feat_extraction_method, class_ids, config=None):
        """
        Args:
            feat_extraction_method: callable img -> list of per-layer features.
            class_ids: iterable of class ids; one CTrBuilder is kept per id.
            config: optional ContrastiveConfig.  Defaults to a fresh instance
                per FeatureMaker (the previous mutable default argument shared
                one config object across all instances).
        """
        if config is None:
            config = ContrastiveConfig()
        self.featextractor = feat_extraction_method
        self.c_trs = {ctr: CTrBuilder(config) for ctr in class_ids}
        # Optional channel-wise L2 normalization of backbone features.
        self.norm_bb_feats = False

    def extract_bb_feats(self, img):
        """Run the backbone without tracking gradients."""
        with torch.no_grad():
            return self.featextractor(img)

    def create_and_fit(self, c_tr, q_img, s_img, s_mask, q_feat, s_feat):
        """Augment the episode, extract features of the augmented images and
        fit the contrastive heads."""
        print('doing contrastive')
        c_tr.makeAugmented(q_img, s_img, s_mask)

        bsz, k, c, h, w = s_img.shape
        aug = c_tr.augmentator.config.n_transformed_imgs

        # Augmented query features: flatten (bsz, aug) for the backbone, then
        # restore the leading dims per layer.
        q_feataug = self.extract_bb_feats(c_tr.augimgs.q.view(-1, c, h, w))
        q_feataug = [l.view(bsz, aug, *l.shape[1:]) for l in q_feataug]

        # Augmented support features: same, with the extra shot dimension.
        s_feataug = self.extract_bb_feats(c_tr.augimgs.s.view(-1, c, h, w))
        s_feataug = [l.view(bsz, k, aug, *l.shape[1:]) for l in s_feataug]

        c_tr.build_and_fit(q_feat, s_feat, q_feataug, s_feataug)

    def taskAdapt(self, q_img, s_img, s_mask, class_id):
        """Return task-adapted (query, support) per-layer feature lists for the
        given class, fitting the heads first if needed."""

        def ch_norm(t):
            # keepdim=True so the (b, c, h, w) tensor divides by a broadcastable
            # (b, 1, h, w) norm (without it the division shapes mismatch).
            return t / torch.linalg.norm(t, dim=1, keepdim=True)

        q_feat = self.extract_bb_feats(q_img)
        bsz, k, c, h, w = s_img.shape
        s_feat = self.extract_bb_feats(s_img.view(-1, c, h, w))
        if self.norm_bb_feats:
            q_feat = [ch_norm(l) for l in q_feat]
            # Bug fix: previously normalized q_feat into the support slot.
            s_feat = [ch_norm(l) for l in s_feat]
        s_feat = [l.view(bsz, k, *l.shape[1:]) for l in s_feat]

        c_tr = self.c_trs[class_id]

        if c_tr.hasfit is False or c_tr.config.featext.fit_every_episode:
            self.create_and_fit(c_tr, q_img, s_img, s_mask, q_feat, s_feat)

        q_feat_t, s_feat_t = c_tr.getTaskAdaptedFeats(q_feat), c_tr.getTaskAdaptedFeats(s_feat)
        return q_feat_t, s_feat_t
|
|
|
|