# ABCDFSS / core / contrastivehead.py
# heyoujue: add submission code (commit 322161a)
import torch.nn.functional as F
import torch
import torch.nn as nn
from utils import segutils
import core.denseaffinity as dautils
def identity_mapping(x, *args, **kwargs):
    """Return x unchanged, ignoring any extra positional/keyword arguments.

    Used as a drop-in replacement for F.normalize when feature normalization
    is disabled in the fitting config (see ContrastiveFeatureTransformer.fit).
    """
    return x
class ContrastiveConfig:
    """Attribute-style access to a nested settings dictionary.

    Chained access such as cfg.fitting.nce.temperature works because
    __getattr__ wraps nested dicts in a fresh ContrastiveConfig around the
    SAME dict object — so assignments made through a nested view mutate the
    underlying configuration as well.
    """
    def __init__(self, config=None):
        # Define the internal dictionary with default settings.
        if config is None:
            self._data = {
                'aug': {
                    'n_transformed_imgs': 2,
                    'blurkernelsize': [1], # chooses one of this kernel sizes
                    'maxjitter': 0.0,
                    'maxangle': 0, # rotation
                    # 'translate': (0,0), # BE CAREFUL WITH TRANSLATE - if you apply it on the feature volume that has smaller spatial dims correspondences break
                    'maxscale': 1.0, # 1.0 = No scaling
                    'maxshear': 20,
                    'randomhflip': False,
                    'apply_affine': True,
                    'debug': False
                },
                'model': {
                    'out_channels': 64,
                    'kernel_size': 1,
                    'prepend_relu': False,
                    'append_normalize': False,
                    'debug': False
                },
                'fitting': {
                    'lr': 1e-2,
                    'optimizer': torch.optim.SGD,
                    'num_epochs': 25,
                    'nce': {
                        'temperature': 0.5,
                        'debug': False
                    },
                    'normalize_after_fwd_pass': True,
                    'q_nceloss': True,
                    's_nceloss': True,
                    'protoloss': False,
                    'keepvarloss': True,
                    'symmetricloss': False,
                    'selfattentionloss': False,
                    'o_t_contr_proto_loss': True,
                    'debug': False
                },
                'featext': {
                    'l0': 3, # the first resnet bottleneck id to consider (0,1,2,3,4,5...15)
                    'fit_every_episode': False
                }
            }
        else:
            self._data = config
    def __getattr__(self, key):
        # Only invoked when normal attribute lookup fails, i.e. for setting names.
        # Try to get '_data' without causing a recursive call to __getattr__
        _data = super().__getattribute__('_data') if '_data' in self.__dict__ else None
        if _data is not None and key in _data:
            if isinstance(_data[key], dict):
                # Wrap nested dicts so chained access (cfg.aug.debug) keeps working.
                return ContrastiveConfig(_data[key])
            return _data[key]
        # If we're here, it means the key was not found in the data,
        # so we let Python raise the appropriate AttributeError.
        raise AttributeError(f"No setting named {key}")
    def __setattr__(self, key, value):
        # Prevent overwriting of the '_data' attribute by normal means
        if key == '_data':
            super().__setattr__(key, value)
        else:
            # All other assignments are routed into the settings dict itself.
            # Try to get '_data' without causing a recursive call to __getattr__
            _data = super().__getattribute__('_data') if '_data' in self.__dict__ else None
            if _data is not None:
                _data[key] = value
            else:
                # This situation should not normally occur, handle appropriately (e.g., log an error, raise exception)
                raise AttributeError("Unexpected")
    # Optional: Representation for better debugging.
    def __repr__(self):
        return str(self._data)
def dense_info_nce_loss(original_features, transformed_features, config_nce):
    """Dense InfoNCE between matching spatial locations of two aligned feature maps.

    original_features is broadcast to the batch of transformed_features; each
    location forms its own positive pair, while every other location of the same
    image serves as a negative.
    """
    B, C, H, W = transformed_features.shape
    # [B,C,H,W] -> [B, H*W, C]: one feature vector per spatial position.
    orig_vecs = original_features.expand(B, C, H, W).permute(0, 2, 3, 1).view(B, H * W, C)
    trans_vecs = transformed_features.permute(0, 2, 3, 1).view(B, H * W, C)
    # Positive logits: dot products of corresponding positions.
    pos = torch.einsum('bik,bik->bi', orig_vecs, trans_vecs) / config_nce.temperature
    # Full logit matrix: every original position against every transformed one.
    logits = torch.einsum('bik,bjk->bij', orig_vecs, trans_vecs) / config_nce.temperature
    if config_nce.debug:
        print('pos/neg:', pos.mean().detach(), logits.mean().detach())
    # Numerically stable log-sum-exp (subtract the row-wise max first).
    row_max = torch.max(logits, dim=-1, keepdim=True).values
    lse = row_max + torch.log(torch.sum(torch.exp(logits - row_max), dim=-1, keepdim=True))
    nce = -(pos - lse.squeeze())
    return nce.mean()  # [B=k*aug] or [B=k] -> scalar
def ssim(a, b):
    """Cosine similarity between a and b along dim 1 (the channel dimension)."""
    return F.cosine_similarity(a, b, dim=1)
def augwise_proto(feat_vol, mask, k, aug):
    """One fg and one bg prototype per augmentation, pooled over all k shots.

    feat_vol: [k*aug, c, h, w]; mask: matching spatial masks. Returns two
    [aug, c] tensors (fg_proto, bg_proto).
    """
    c, h, w = feat_vol.shape[-3:]
    # Regroup [k*aug, c, h, w] into per-augmentation rows [aug, c, k*h*w]:
    # split along the shot dimension, then concatenate flattened spatial dims.
    shotwise_feats = feat_vol.view(k, aug, c, h * w)
    feature_vectors_augwise = torch.cat(shotwise_feats.unbind(0), dim=-1)
    shotwise_masks = segutils.downsample_mask(mask, h, w).view(k, aug, h * w)
    mask_augwise = torch.cat(shotwise_masks.unbind(0), dim=-1)
    assert feature_vectors_augwise.shape == (aug, c, k * h * w) and mask_augwise.shape == (
        aug, k * h * w), "of transformed"
    fg_proto, bg_proto = segutils.fg_bg_proto(feature_vectors_augwise, mask_augwise)
    assert fg_proto.shape == bg_proto.shape == (aug, c)
    return fg_proto, bg_proto
def calc_q_pred_coarse_nodetach(qft, sft, s_mask, l0=3):
    """Coarse query prediction via dense affinity to the support, keeping gradients.

    qft: [bsz,c,hq,wq]; sft: [bsz,k,c,hs,ws]; s_mask: [bsz,k,H,W].
    Returns a [bsz,hq,wq] prediction map.
    """
    bsz, c, hq, wq = qft.shape
    hs, ws = sft.shape[-2:]
    # Lay the k support feature maps side by side: [bsz,k,c,h,w] -> [bsz,c,h,w*k]
    sft_row = torch.cat(sft.unbind(1), -1)
    # Same row layout for the downsampled support masks.
    smask_row = torch.cat([segutils.downsample_mask(m, hs, ws) for m in s_mask.unbind(1)], -1)
    damat = dautils.buildDenseAffinityMat(qft, sft_row)
    filtered = dautils.filterDenseAffinityMap(damat, smask_row)
    return filtered.view(bsz, hq, wq)
# input k*aug,c,h,w
def self_attention_loss(f_base, f_transformed, mask_base, mask_transformed, k, aug):
    """BCE between an affinity-based prediction and the base mask.

    The non-transformed views act as a pseudo query row, the transformed views
    as a pseudo support set; inputs are [k*aug,c,h,w] features and [k*aug,h,w] masks.
    """
    c, h, w = f_base.shape[-3:]
    # Pseudo query: concatenate the k shots side by side -> [aug, c, h, w*k]
    pseudoquery = torch.cat(f_base.view(k, aug, c, h, w).unbind(0), -1)
    pseudoquerymask = torch.cat(mask_base.view(k, aug, h, w).unbind(0), -1)  # [aug, h, w*k]
    # Pseudo support: augmentations become the batch dim -> [aug, k, c, h, w]
    pseudosupport = f_transformed.view(k, aug, c, h, w).transpose(0, 1)
    pseudosupportmask = mask_transformed.view(k, aug, h, w).transpose(0, 1)  # [aug, k, h, w]
    pred_map = calc_q_pred_coarse_nodetach(pseudoquery, pseudosupport, pseudosupportmask, l0=0)
    bce = torch.nn.BCELoss()(pred_map.float(), pseudoquerymask.float())
    return bce.mean()
# features of base, transformed: [b,c,h,w]
# if base features are aligned with transformed features, pass both same
def ctrstive_prototype_loss(base, transformed, mask_base, mask_transformed, k, aug):
    """Contrastive loss over per-augmentation prototypes.

    Positive pair: fg prototype of base vs. fg prototype of transformed;
    negative: fg prototype of base vs. bg prototype of transformed. If base
    features are already aligned with transformed features, pass both the same.
    """
    assert transformed.shape == base.shape, ".."
    b, c, h, w = base.shape
    assert b == k * aug, 'provide correct k and aug such that dim0=k*aug'
    assert mask_base.shape == mask_transformed.shape == (b, h, w), ".."
    fg_proto_o, bg_proto_o = augwise_proto(base, mask_base, k, aug)
    fg_proto_t, bg_proto_t = augwise_proto(transformed, mask_transformed, k, aug)
    pos = torch.exp(ssim(fg_proto_o, fg_proto_t))
    neg = torch.exp(ssim(fg_proto_o, bg_proto_t))
    enumer = pos
    denom = pos + neg
    assert enumer.shape == denom.shape == torch.Size([aug]), 'you want to calculate one prototype for each augmentation'
    return (-torch.log(enumer / denom)).mean()
def opposite_proto_sim_in_aug(transformed_features, mapped_s_masks, k, aug):
    """Mean cosine similarity between fg and bg prototypes of the transformed features."""
    fg, bg = augwise_proto(transformed_features, mapped_s_masks, k, aug)
    return ssim(fg, bg).mean()
def proto_align_val_measure(original_features, transformed_features, mapped_s_masks, k, aug):
    """Validation metric: mean cosine similarity of fg prototypes across views."""
    fg_orig, _ = augwise_proto(original_features, mapped_s_masks, k, aug)
    fg_trans, _ = augwise_proto(transformed_features, mapped_s_masks, k, aug)
    return ssim(fg_orig, fg_trans).mean()
def atest():
    """Smoke test: run self_attention_loss on random inputs and return the loss."""
    k, aug, c, h, w = 2, 5, 8, 20, 20
    base_feats = torch.rand(k * aug, c, h, w).float()
    base_feats.requires_grad = True
    trans_feats = torch.rand(k * aug, c, h, w).float()
    base_masks = torch.randint(0, 2, (k * aug, h, w)).float()
    trans_masks = torch.randint(0, 2, (k * aug, h, w)).float()
    return self_attention_loss(base_feats, trans_feats, base_masks, trans_masks, k, aug)
def keep_var_loss(original_features, transformed_features):
    """Penalize drift of per-channel mean and variance between the two feature sets.

    Statistics are taken over the spatial dims ([..,c,h,w] -> [..,c]); the loss
    is the mean absolute gap of the means plus that of the variances.
    """
    mean_gap = original_features.mean((-2, -1)) - transformed_features.mean((-2, -1))
    var_gap = original_features.var((-2, -1)) - transformed_features.var((-2, -1))
    return mean_gap.abs().mean() + var_gap.abs().mean()
class ContrastiveFeatureTransformer(nn.Module):
    """Small convolutional projection head: (ReLU?) -> conv -> BN -> ReLU -> 1x1 conv (-> normalize?).

    Fitted per episode via fit() with a sum of contrastive objectives so the
    backbone features of original and augmented views align densely.
    """
    def __init__(self, in_channels, config_model):
        super(ContrastiveFeatureTransformer, self).__init__()
        out_channels, kernel_size = config_model.out_channels, config_model.kernel_size
        # Add a convolutional layer and a batch normalization layer for learning
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2)
        self.bn = nn.BatchNorm2d(out_channels)
        self.linear = nn.Conv2d(out_channels, out_channels, 1)
        # Optional pre-activation and output L2-normalization, toggled by the model config.
        self.prepend_relu = config_model.prepend_relu
        self.append_normalize = config_model.append_normalize
        self.debug = config_model.debug
    def forward(self, x):
        # x: feature volume [B, in_channels, H, W] -> [B, out_channels, H, W]
        if self.prepend_relu:
            x = nn.ReLU()(x)
        x = self.conv(x)
        x = self.bn(x)
        x = nn.ReLU()(x)
        x = self.linear(x)
        if self.append_normalize:
            # Channel-wise L2 normalization of each spatial feature vector.
            x = F.normalize(x, p=2, dim=1)
        return x
    # fits the model for one semantic class, therefore does not work with batches
    # mapped_qfeat_vol, aug_qfeat_vols: [aug,c,h,w]
    # mapped_sfeat_vol, aug_sfeat_vols: [k*aug,c,h,w]
    # augmented_smasks: [k*aug,h,w]
    def fit(self, mapped_qfeat_vol, aug_qfeat_vols, mapped_sfeat_vol, aug_sfeat_vols, augmented_smasks, config_fit):
        """Optimize this head on one episode with a configurable sum of losses.

        'mapped' volumes are features of the untransformed images with the same
        affines applied afterwards, so they align spatially with the 'aug'
        volumes (features of the augmented images). Losses (all optional via
        config_fit): dense InfoNCE on query and support, mean/var preservation,
        one of three prototype/self-attention losses, and a |qloss - sloss|
        symmetry regularizer.
        """
        f_norm = F.normalize if config_fit.normalize_after_fwd_pass else identity_mapping
        optimizer = config_fit.optimizer(self.parameters(), lr=config_fit.lr)
        for epoch in range(config_fit.num_epochs):
            # Pass original and transformed image batches through the model
            # Q
            original_features = f_norm(self(mapped_qfeat_vol), p=2, dim=1)  # fwd pass non-augmented
            transformed_features = f_norm(self(aug_qfeat_vols), p=2, dim=1)  # fwd pass augmented
            qloss = dense_info_nce_loss(original_features, transformed_features,
                                        config_fit.nce) if config_fit.q_nceloss else 0
            if config_fit.keepvarloss:  # 1. idea: Let query and support have the same feature distribution (mean/var per channel)
                qloss += keep_var_loss(original_features, transformed_features)
            # S (note: original_features/transformed_features are rebound to the support
            # results here; the prototype losses below operate on the SUPPORT features)
            original_features = f_norm(self(mapped_sfeat_vol), p=2, dim=1)  # fwd pass non-augmented
            transformed_features = f_norm(self(aug_sfeat_vols), p=2, dim=1)  # fwd pass augmented
            sloss = dense_info_nce_loss(original_features, transformed_features,
                                        config_fit.nce) if config_fit.s_nceloss else 0
            if config_fit.keepvarloss:
                sloss += keep_var_loss(original_features, transformed_features)
            # 2. class-aware loss: opposite classes should get opposite features
            # for prototype calculation, we want only one prototype per class
            # so we average over features of entire k
            # but calculate prototype for each augmentation individually [k*aug,c,h,w]->[aug,c,k*h*w]->[aug,c]
            kaug, c, h, w = transformed_features.shape
            aug = aug_qfeat_vols.shape[0]
            k = kaug // aug
            if config_fit.protoloss:
                assert not config_fit.o_t_contr_proto_loss, 'only one of the proto losses should be used'
                opposite_proto_sim = opposite_proto_sim_in_aug(transformed_features, augmented_smasks, k, aug)
                if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0): print(
                    'proto-sim intER-class transf<->transf', opposite_proto_sim.item())
                proto_loss = opposite_proto_sim
            elif config_fit.selfattentionloss:
                proto_loss = self_attention_loss(original_features, transformed_features, augmented_smasks,
                                                 augmented_smasks, k, aug)
                if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0): print(
                    'self-att non-transf<->transformed bce', proto_loss.item())
            elif config_fit.o_t_contr_proto_loss:
                o_t_contr_proto_loss = ctrstive_prototype_loss(original_features, transformed_features,
                                                               augmented_smasks, augmented_smasks, k, aug)
                if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0): print(
                    'proto-contr non-transf<->transformed', o_t_contr_proto_loss.item())
                proto_loss = o_t_contr_proto_loss
            else:
                proto_loss = 0
            if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0):
                proto_align_val = proto_align_val_measure(original_features, transformed_features, augmented_smasks, k,
                                                          aug)
                print('proto-sim intRA-class non-transf<->transformed (for validation)', proto_align_val.item())
            # 3. do not let only one image fit well - regularization
            q_s_loss_diff = torch.abs(qloss - sloss) if config_fit.symmetricloss else 0
            # Aggregate loss
            loss = qloss + sloss + q_s_loss_diff + proto_loss
            assert loss.isfinite().all(), f"invalid contrastive loss:{loss}"
            # Backpropagation and optimization
            if config_fit.debug and (epoch == config_fit.num_epochs - 1 or epoch == 0):
                # Per-term gradient magnitudes, for diagnosing loss balancing.
                def gradient_magnitude(loss_term):
                    optimizer.zero_grad()
                    loss_term.backward(retain_graph=True)
                    magn = torch.abs(self.conv.weight.grad.mean()) + torch.abs(self.linear.weight.grad.mean())
                    return magn
                q_loss_grad_magnitude = gradient_magnitude(qloss)
                s_loss_grad_magnitude = gradient_magnitude(sloss)
                proto_loss_grad_magnitude = gradient_magnitude(proto_loss)
                q_s_loss_diff_grad_magnitude = gradient_magnitude(q_s_loss_diff)
                # NOTE(review): display() is an IPython/notebook builtin — this debug
                # branch only works in a notebook environment.
                display(segutils.tensor_table(q_loss_grad_magnitude=q_loss_grad_magnitude,
                                              s_loss_grad_magnitude=s_loss_grad_magnitude,
                                              proto_loss_grad_magnitude=proto_loss_grad_magnitude,
                                              q_s_loss_diff_grad_magnitude=q_s_loss_diff_grad_magnitude))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if config_fit.debug and epoch % 10 == 0: print('loss', loss.detach())
import numpy as np
import torch.nn.functional as F
from torchvision.transforms.functional import affine
from torchvision.transforms import GaussianBlur, ColorJitter
class AffineProxy:
    """Stores one fixed set of affine parameters and applies the same transform to any tensor.

    Keeping the parameters lets the identical geometric transform be replayed
    on images, masks, and feature volumes alike.
    """
    def __init__(self, angle, translate, scale, shear):
        self.affine_params = dict(angle=angle, translate=translate, scale=scale, shear=shear)
    def apply(self, img):
        """Apply the stored affine transform to img."""
        params = self.affine_params
        return affine(img, angle=params['angle'], translate=params['translate'],
                      scale=params['scale'], shear=params['shear'])
# def affine_proxy(angle, translate, scale, shear):
# def inner(img):
# return affine(img, angle=angle, translate=translate, scale=scale, shear=shear)
# return inner
class Augmen:
    """Samples a fixed set of photometric + affine augmentations once, then reuses them.

    The same sampled affines are replayed by applyAffines so feature volumes of
    the untransformed image can be geometrically aligned with the augmented views.
    """
    def __init__(self, config_aug):
        self.config = config_aug
        self.blurs, self.jitters, self.affines = self.setup_augmentations()
    def copy_construct(self, blurs, jitters, affines, config_aug):
        # Re-initialize this instance with already-sampled transforms (no fresh sampling).
        self.config = config_aug
        self.blurs, self.jitters, self.affines = blurs, jitters, affines
    def setup_augmentations(self):
        """Sample n_transformed_imgs (blur, jitter, affine) triples from the config ranges."""
        blurkernelsize = self.config.blurkernelsize
        maxjitter = self.config.maxjitter
        maxangle = self.config.maxangle
        # Translation is hard-coded off; see the warning in ContrastiveConfig about
        # broken correspondences on smaller feature maps.
        translate = (0, 0)
        maxscale = self.config.maxscale
        maxshear = self.config.maxshear
        blurs = []
        jitters = []
        affine_trans = []
        for i in range(self.config.n_transformed_imgs):
            # Randomize kernel size for GaussianBlur
            kernel_size = np.random.choice(torch.tensor(blurkernelsize), (1,)).item()
            blur = GaussianBlur(kernel_size)
            blurs.append(blur)
            # Randomize values for ColorJitter
            brightness_val = torch.rand(1).item() * maxjitter # up to <maxjitter> change
            contrast_val = torch.rand(1).item() * maxjitter
            saturation_val = torch.rand(1).item() * maxjitter
            jitter = ColorJitter(brightness=brightness_val, contrast=contrast_val, saturation=saturation_val)
            jitters.append(jitter)
            # Random values for each iteration
            angle = torch.randint(-maxangle, maxangle + 1, (1,)).item()
            shear = [torch.randint(-maxshear, maxshear + 1, (1,)).item() for _ in range(2)]
            # scale is sampled in [maxscale, 1); maxscale=1.0 therefore means no scaling
            scale = torch.rand(1).item() * (1 - maxscale) + maxscale
            affine_trans.append(AffineProxy(angle=angle, translate=translate, scale=scale, shear=shear))
        return (blurs, jitters, affine_trans) # tuple of lists
    def augment(self, original_image, orignal_mask):
        """Apply each sampled (blur, jitter, affine) triple; returns (imgs, masks) stacked at dim=1."""
        transformed_imgs = []
        transformed_masks = []
        for blur, jitter, affine_trans in zip(self.blurs, self.jitters, self.affines):
            # Apply non-geometric transformations
            t_img = blur(original_image)
            t_img = jitter(t_img)
            t_mask = orignal_mask.clone()
            if self.config.apply_affine:
                # Geometric transform is applied to image and mask alike to keep them aligned.
                t_img = affine_trans.apply(t_img)
                t_mask = affine_trans.apply(t_mask)
            transformed_imgs.append(t_img)
            transformed_masks.append(t_mask)
        return torch.stack(transformed_imgs, dim=1), torch.stack(transformed_masks, dim=1)
    # [bsz,ch,h,w] -> [bsz,aug,ch,h,w], where aug is the number of augmentated images
    def applyAffines(self, feat_vol):
        """Replay only the stored affine parts (no blur/jitter) on a feature volume."""
        return torch.stack([trans.apply(feat_vol) for trans in self.affines], dim=1)
class CTrBuilder:
    """Builds and fits one ContrastiveFeatureTransformer per feature layer for a single class.

    Usage order: __init__ -> makeAugmented -> build_and_fit -> getTaskAdaptedFeats.
    """
    # call init 1st, pass all config parameters (instantiate a ContrastiveConfig class in your code)
    def __init__(self, config, augmentator=None):
        if augmentator is None:
            augmentator = Augmen(config.aug)
        self.augmentator = augmentator
        self.augimgs = self.AugImgStack(augmentator)
        self.hasfit = False
        self.config = config
    class AugImgStack():
        """Holds the augmented query/support images and support masks for one episode."""
        def __init__(self, augmentator):
            self.augmentator = augmentator
            self.q, self.s, self.s_mask = None, None, None
        def init(self, s_img):
            """Allocate empty buffers matching s_img's device and spatial dims."""
            # c is color channels here, not feature channels
            bsz, k, aug, c, h, w = *s_img.shape[:2], self.augmentator.config.n_transformed_imgs, *s_img.shape[-3:]
            self.q = torch.empty(bsz, aug, c, h, w).to(s_img.device)
            self.s = torch.empty(bsz, k, aug, c, h, w).to(s_img.device)
            self.s_mask = torch.empty(bsz, k, aug, h, w).to(s_img.device)
        def show(self):
            """Visualize the augmented episode.

            NOTE(review): display() is an IPython/notebook builtin; call only in a notebook.
            """
            bsz_, k_, aug_ = self.s.shape[:3]
            for b in range(bsz_):
                display('aug x queries', segutils.pilImageRow(*[segutils.norm(img) for img in self.q[b]]))
                for k in range(k_):
                    print('k=', k, ' aug x (s, smask):')
                    display(segutils.pilImageRow(*[segutils.norm(img) for img in self.s[b, k]]))
                    display(segutils.pilImageRow(*self.s_mask[b, k]))
    def showAugmented(self):
        self.augimgs.show()
    # 2nd call makeAugmented
    def makeAugmented(self, q_img, s_img, s_mask):
        """Create augmented query/support images and support masks for the episode.

        q_img: [bsz,c,h,w]; s_img: [bsz,k,c,h,w]; s_mask: [bsz,k,h,w].
        """
        # 2. Augmentation
        # 2.1 Apply transformations to images
        self.augimgs.init(s_img)
        # NOTE(review): s_mask is passed only because augment() requires a mask
        # argument; the transformed mask output for the query is discarded.
        self.augimgs.q, _ = self.augmentator.augment(q_img, s_mask)
        for k in range(s_img.shape[1]):
            s_aug_imgs, s_aug_masks = self.augmentator.augment(s_img[:, k], s_mask[:, k])
            self.augimgs.s[:, k] = s_aug_imgs
            self.augimgs.s_mask[:, k] = s_aug_masks
        if self.config.aug.debug: self.augimgs.show()
    # 3rd call build_and_fit
    def build_and_fit(self, q_feat, s_feat, q_feataug, s_feataug, s_maskaug=None):
        """Fit one contrastive head per considered feature layer and mark the builder as fitted."""
        if s_maskaug is None: s_maskaug = self.augimgs.s_mask
        self.ctrs = self.buildContrastiveTransformers(q_feat, s_feat, q_feataug, s_feataug, s_maskaug)
        self.hasfit = True
    def buildContrastiveTransformers(self, qfeat_alllayers, sfeat_alllayers, query_feats_aug, support_feats_aug,
                                     supp_aug_mask, s_mask=None):
        """Build + fit a ContrastiveFeatureTransformer for every backbone layer from l0 on.

        Returns the list of fitted heads (one per layer). Pass s_mask only for a
        notebook visualization of the resulting affinity maps.
        """
        contrastive_transformers = []
        l0 = self.config.featext.l0
        # [bsz,k,aug,h,w] -> [k*aug,h,w]
        s_aug_mask = supp_aug_mask.view(-1, *supp_aug_mask.shape[-2:])
        # iterate over feature layers
        for (qfeat, sfeat, qfeataug, sfeataug) in zip(qfeat_alllayers[l0:], sfeat_alllayers[l0:], query_feats_aug[l0:],
                                                      support_feats_aug[l0:]):
            bsz, k, aug, ch, h, w = sfeataug.shape
            # we fit it for exactly one class, so use no batches
            assert bsz == 1, "bsz should be 1"
            assert supp_aug_mask.shape[1] == sfeat.shape[
                1] == k, f'augmented support shot-dimension mismatch:{s_aug_mask.shape[1]=},{sfeat.shape[1]=},(bsz,k,aug,ch,h,w)={bsz, k, aug, ch, h, w}'
            assert supp_aug_mask.shape[2] == qfeataug.shape[1] == aug, 'augmented shot-dimension mismatch'
            # [bsz,c,h,w] -> [1,c,h,w]
            qfeat = qfeat.view(-1, *qfeat.shape[-3:])
            # [bsz,k,c,h,w] -> [k,c,h,w]
            sfeat = sfeat.view(-1, *sfeat.shape[-3:])
            # [bsz,aug,c,h,w] -> [aug,c,h,w]
            qfeataug = qfeataug.view(-1, *qfeataug.shape[-3:])
            # [bsz,k,aug,c,h,w] -> [k*aug,c,h,w]
            # BUG FIX: previously reshaped with qfeataug's trailing dims; use sfeataug's own.
            sfeataug = sfeataug.view(-1, *sfeataug.shape[-3:])
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            contrastive_head = ContrastiveFeatureTransformer(in_channels=ch, config_model=self.config.model).to(device)
            # 3. Feature volumes from untransformed image need to be geometrically mapped to allow for dense matching
            mapped_qfeat = self.augmentator.applyAffines(qfeat)
            assert mapped_qfeat.shape[1] == aug, "should be 1,aug,c,h,w"
            mapped_qfeat = mapped_qfeat.view(-1, *qfeat.shape[-3:])  # ->[aug,c,h,w]
            mapped_sfeat = self.augmentator.applyAffines(sfeat)
            assert mapped_sfeat.shape[1] == aug and mapped_sfeat.shape[0] == k, "should be k,aug,c,h,w"
            mapped_sfeat = mapped_sfeat.view(-1, *sfeat.shape[-3:])  # ->[k*aug,c,h,w]
            contrastive_head.fit(mapped_qfeat, qfeataug, mapped_sfeat, sfeataug,
                                 segutils.downsample_mask(s_aug_mask, h, w), self.config.fitting)
            contrastive_transformers.append(contrastive_head)
            # show how support image and its augmentations would produce a affinity map
            if s_mask is not None:  # identity comparison (was `!= None`, which invokes tensor comparison semantics)
                display(segutils.to_pil(segutils.norm(dautils.filterDenseAffinityMap(
                    dautils.buildDenseAffinityMat(contrastive_head(sfeat), contrastive_head(sfeataug[:1])),
                    segutils.downsample_mask(s_mask, h, w)).view(1, h, w))))
                display(segutils.to_pil(segutils.norm(dautils.filterDenseAffinityMap(
                    dautils.buildDenseAffinityMat(contrastive_head(qfeat), contrastive_head(sfeat)),
                    segutils.downsample_mask(s_mask, h, w)).view(1, h, w))))
        return contrastive_transformers
    # You have fitted the contrastive transformers, now apply the transform and then pass to the downstream DCAMA
    # you just need to append the empty layers you exluded ([:3]), they're also skipped in dcama
    # Obtain the result of the contrastive head, which will be the new query and support feat representation
    def getTaskAdaptedFeats(self, layerwise_feats):
        """Run each layer's features through its fitted head; layers below l0 map to None."""
        if self.ctrs is None: print("error: call buildContrastiveTransformers() first")
        task_adapted_feats = []
        for idx in range(len(layerwise_feats)):
            if idx < self.config.featext.l0:
                task_adapted_feats.append(None)
            else:
                input_shape = layerwise_feats[idx].shape
                idxth_feat = layerwise_feats[idx].view(-1, *input_shape[-3:])
                forward_pass_res = self.ctrs[idx - self.config.featext.l0](idxth_feat)
                target_shape = *input_shape[:-3], *forward_pass_res.shape[
                    -3:]  # borrow channel dim from result, but bsz,k dims from input
                task_adapted_feats.append(forward_pass_res.view(target_shape))
        return task_adapted_feats
class FeatureMaker:
    """Extracts backbone features and task-adapts them per class via a CTrBuilder.

    Keeps one CTrBuilder per class id; each is fitted lazily on first use (or
    re-fitted every episode when featext.fit_every_episode is set).
    """
    def __init__(self, feat_extraction_method, class_ids, config=None):
        # BUG FIX: the default used to be a mutable ContrastiveConfig() default
        # argument, evaluated once and silently shared between all FeatureMaker
        # instances. Build a fresh config per instance instead.
        if config is None:
            config = ContrastiveConfig()
        self.featextractor = feat_extraction_method
        self.c_trs = {ctr: CTrBuilder(config) for ctr in class_ids}
        self.norm_bb_feats = False
    def extract_bb_feats(self, img):
        """Backbone forward pass without gradients; returns a layer-list of feature maps."""
        with torch.no_grad():
            return self.featextractor(img)
    def create_and_fit(self, c_tr, q_img, s_img, s_mask, q_feat, s_feat):
        """Augment the episode, extract features of the augmented images, and fit c_tr."""
        print('doing contrastive')
        c_tr.makeAugmented(q_img, s_img, s_mask)
        bsz, k, c, h, w = s_img.shape
        aug = c_tr.augmentator.config.n_transformed_imgs
        # [bsz,aug,c,h,w]->[bsz*aug,c,h,w] squeeze for forward pass
        q_feataug = self.extract_bb_feats(c_tr.augimgs.q.view(-1, c, h, w))  # returns layer-list
        # then restore
        q_feataug = [l.view(bsz, aug, *l.shape[1:]) for l in q_feataug]
        # [bsz,k,aug,c,h,w]->[bsz*k*aug,c,h,w]->[bsz,k,aug,c,h,w]
        s_feataug = self.extract_bb_feats(c_tr.augimgs.s.view(-1, c, h, w))
        s_feataug = [l.view(bsz, k, aug, *l.shape[1:]) for l in s_feataug]
        c_tr.build_and_fit(q_feat, s_feat, q_feataug, s_feataug)
    def taskAdapt(self, q_img, s_img, s_mask, class_id):
        """Return task-adapted (query, support) layer-lists for the given class."""
        # BUG FIX: keepdim=True so the channel norm broadcasts over [b,c,h,w];
        # without it the division had mismatched shapes.
        ch_norm = lambda t: t / torch.linalg.norm(t, dim=1, keepdim=True)
        q_feat = self.extract_bb_feats(q_img)
        bsz, k, c, h, w = s_img.shape
        s_feat = self.extract_bb_feats(s_img.view(-1, c, h, w))
        if self.norm_bb_feats:
            q_feat = [ch_norm(l) for l in q_feat]
            s_feat = [ch_norm(l) for l in s_feat]  # BUG FIX: previously iterated q_feat, normalizing query feats into the support list
        s_feat = [l.view(bsz, k, *l.shape[1:]) for l in s_feat]
        c_tr = self.c_trs[class_id]  # select the relevant ctr for this class
        if c_tr.hasfit is False or c_tr.config.featext.fit_every_episode:  # create and fit a contrastive transformer if not existing yet
            self.create_and_fit(c_tr, q_img, s_img, s_mask, q_feat, s_feat)
        q_feat_t, s_feat_t = c_tr.getTaskAdaptedFeats(q_feat), c_tr.getTaskAdaptedFeats(
            s_feat)  # tocheck: do they require_grad here?
        return q_feat_t, s_feat_t