# Listing metadata from the source viewer: file size 9,246 bytes, identifier c6dfc69.
from abc import ABC
import torch
import torch.nn as nn
class ContrastLoss(nn.Module, ABC):
    """Audio-visual InfoNCE loss over class-wise sampled pixel embeddings.

    Visual pixel embeddings act as anchors; they are contrasted against each
    other and against per-frame audio embeddings.  Pairs that share a label
    are positives, every other pair is a negative.

    Hyper-parameters are read from ``hyp_param.contrastive_learning`` (a
    mapping); missing keys fall back to the defaults listed in ``__init__``.
    ``hyp_param.image_size`` is read by ``forward``.
    """

    # Number of embedding levels consumed per frame (indices 0, 1, 2).
    _NUM_LEVELS = 3

    def __init__(self, hyp_param):
        """Store ``hyp_param`` and resolve the contrastive-learning settings.

        Args:
            hyp_param: configuration object.  Its optional
                ``contrastive_learning`` attribute must be a mapping (or
                ``None``/empty) whose entries override the defaults below.
        """
        super().__init__()
        self.param = hyp_param
        defaults = {
            "temperature": 0.10,
            "ignore_idx": 255,
            "ood_idx": 254,
            "max_views": 512,
            "proj_dim": 512,
            "sample_limits": 128,
            "total_limits": 64,
        }
        overrides = getattr(hyp_param, "contrastive_learning", None) or {}
        cfg = {**defaults, **overrides}
        self.temperature = cfg["temperature"]
        # NOTE(review): ignore_idx / ood_idx / max_views / proj_dim are stored
        # but not read anywhere in this class; pixels labelled ignore_idx are
        # sampled like any other class.
        self.ignore_idx = cfg["ignore_idx"]
        self.ood_idx = cfg["ood_idx"]
        self.max_views = cfg["max_views"]
        self.proj_dim = cfg["proj_dim"]
        self.sample_limits = cfg["sample_limits"]
        self.total_limits = cfg["total_limits"]

    def select_class_wise_samples(self, embeddings, audio_embeddings, predictions, masks, batch_idx):
        """Sample hard and easy pixel embeddings for every class in ``masks``.

        A pixel is "hard" when ``masks != predictions`` at that position and
        "easy" otherwise.

        Args:
            embeddings: 2-D tensor, transposed here so that rows line up with
                the entries of ``masks`` (i.e. given as (channels, pixels)).
            audio_embeddings: audio embedding for this frame/level; one copy
                (with a leading dim of 1) is emitted per non-first class.
            predictions: 1-D tensor compared element-wise with ``masks``.
            masks: 1-D tensor of per-pixel class labels.
            batch_idx: integer; ``batch_idx * 1e3`` is added to every label.

        Returns:
            ``(visual_embeddings, visual_labels, audio_embeddings,
            audio_labels)`` — four lists of tensors.  The visual lists hold
            two entries per class (hard, then easy); each visual embedding
            entry has shape (k, 1, channels) and its label entry has shape
            (k,).
        """
        label_offset = batch_idx * 1e3
        class_index_list = torch.unique(masks)

        # Audio side: one copy of the audio embedding per class except the
        # first (smallest) label; if only one class is present, a single copy
        # labelled 0 is emitted instead.
        # NOTE(review): presumably the first label is background — confirm.
        embedding_sample_list_a = []
        label_list_a = []
        if len(class_index_list) > 1:
            for class_index in class_index_list[1:]:
                embedding_sample_list_a.append(audio_embeddings.unsqueeze(0))
                label_list_a.append(class_index.unsqueeze(0) + label_offset)
        else:
            embedding_sample_list_a.append(audio_embeddings.unsqueeze(0))
            label_list_a.append(torch.zeros([1], device=embeddings.device) + label_offset)

        sample_limits = self.sample_limits
        pixel_embeddings = embeddings.permute(1, 0)
        correct = masks == predictions

        embedding_sample_list = []
        label_list = []
        for class_index in class_index_list:
            in_class = masks == class_index
            # ``nonzero()`` yields a (k, 1) index, so each pool is (k, 1, channels).
            hard_pool = pixel_embeddings[(in_class & ~correct).nonzero()]
            easy_pool = pixel_embeddings[(in_class & correct).nonzero()]

            num_hard = min(sample_limits, hard_pool.shape[0])
            num_easy = min(sample_limits, easy_pool.shape[0])
            # NOTE(review): when one pool is short the other quota is raised
            # with ``+=``, which can exceed a total of ``2 * sample_limits``;
            # kept as-is because the intended budget is not stated here.
            if (num_hard + num_easy) < sample_limits * 2:
                if num_hard > num_easy:
                    num_hard += sample_limits * 2 - num_easy
                else:
                    num_easy += sample_limits * 2 - num_hard

            for pool, quota in ((hard_pool, num_hard), (easy_pool, num_easy)):
                chosen = torch.randperm(pool.shape[0])[:quota]
                embedding_sample_list.append(pool[chosen])
                # Every row of ``pool`` belongs to ``class_index``.  ``chosen``
                # indexes the pool, not ``masks``, so the label must come from
                # the class itself rather than from ``masks[chosen]``.
                label_list.append(class_index.expand(chosen.shape[0]) + label_offset)

        return embedding_sample_list, label_list, embedding_sample_list_a, label_list_a

    def forward_audio_visual(self, visual_embeddings, audio_embeddings, masks, predictions):
        """Gather samples over all frames and levels and return the InfoNCE loss.

        Args:
            visual_embeddings: tensor indexed as [frame][level]; its last two
                dims are flattened into one pixel axis.
            audio_embeddings: tensor indexed as [frame][level].
            masks: per-frame label maps; flattened to (frames, pixels).
            predictions: per-frame predictions; flattened like ``masks``.

        Returns:
            Scalar loss tensor, or the float ``0.0`` when no anchor was
            collected.
        """
        masks = masks.flatten(start_dim=1)
        predictions = predictions.flatten(start_dim=1)
        visual_embeddings = visual_embeddings.flatten(start_dim=-2)

        visual_samples, visual_labels = [], []
        audio_samples, audio_labels = [], []
        for frame_idx in range(masks.shape[0]):
            frame_visual = visual_embeddings[frame_idx]
            frame_audio = audio_embeddings[frame_idx]
            for layer_idx in range(self._NUM_LEVELS):
                v_emb, v_lab, a_emb, a_lab = self.select_class_wise_samples(
                    frame_visual[layer_idx],
                    frame_audio[layer_idx],
                    predictions[frame_idx],
                    masks[frame_idx],
                    0,  # same label offset for all frames: classes match across frames
                )
                visual_samples += v_emb
                visual_labels += v_lab
                audio_samples += a_emb
                audio_labels += a_lab

        if not visual_samples:
            return 0.0

        # Collapse everything after the sample axis.  Unlike a bare
        # ``squeeze()``, this cannot drop the sample axis when only one row
        # was collected, nor a feature axis of size 1.
        anchors = torch.cat(visual_samples, dim=0).flatten(start_dim=1)
        anchor_labels = torch.cat(visual_labels, dim=0).unsqueeze(-1)
        contrast = torch.cat(audio_samples, dim=0).flatten(start_dim=1)
        contrast_labels = torch.cat(audio_labels).unsqueeze(1)

        if anchors.shape[0] == 0:
            # No anchors: the mean over an empty batch would be NaN.
            return 0.0

        total_limits = self.total_limits
        if anchors.shape[0] > total_limits:
            # Keep a uniform random subset of exactly ``total_limits`` anchors,
            # indexing embeddings and labels with the same permutation.
            keep = torch.randperm(anchors.shape[0], device=anchors.device)[:total_limits]
            anchors = anchors[keep]
            anchor_labels = anchor_labels[keep]

        return self.info_nce(anchors, anchor_labels, contrast, contrast_labels)

    @staticmethod
    def _stack_levels(levels, num_frames, num_levels):
        """Stack ``levels[k][i]`` into one tensor indexed as [frame i][level k]."""
        return torch.stack(
            [
                torch.stack([levels[k][i] for k in range(num_levels)])
                for i in range(num_frames)
            ]
        )

    def forward(self, embeddings, output_dicts, masks):
        """Compute the contrastive loss for a batch.

        Args:
            embeddings: pair ``(visual_embeddings, audio_embeddings)``; each
                is indexable as [level][frame] for levels 0..2.
                NOTE(review): audio entries presumably carry a trailing
                singleton dim (it is squeezed with ``squeeze(-1)``) — confirm.
            output_dicts: iterable of dicts holding ``"multistep_pred_masks"``
                tensors, concatenated along dim 0.
            masks: ground-truth label maps, one per frame (floating point, as
                required by ``interpolate``).

        Returns:
            Whatever ``forward_audio_visual`` returns.
        """
        grid = int(self.param.image_size / 16)
        target_size = (grid, grid)

        predictions = torch.cat([out["multistep_pred_masks"] for out in output_dicts])
        # NOTE(review): predictions are compared to mask labels with ==/!=
        # downstream; after bilinear resizing they are continuous values, so
        # confirm whether they should be thresholded first.
        predictions = nn.functional.interpolate(
            predictions,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        ).squeeze(1)
        masks = nn.functional.interpolate(
            masks.unsqueeze(1),
            size=target_size,
            mode="nearest",
        ).squeeze(1)

        visual_embeddings, audio_embeddings = embeddings
        num_frames = masks.shape[0]
        visual_embeddings = self._stack_levels(visual_embeddings, num_frames, self._NUM_LEVELS)
        audio_embeddings = self._stack_levels(audio_embeddings, num_frames, self._NUM_LEVELS)

        return self.forward_audio_visual(
            visual_embeddings, audio_embeddings.squeeze(-1), masks, predictions
        )

    @staticmethod
    def manipulate_cover_mask(a_label, current_mask):
        """Zero selected anchor-anchor entries of ``current_mask`` in place.

        With ``p = (a_label + 1) @ (a_label + 1).T``, entries of the leading
        square block of ``current_mask`` are cleared wherever ``p == 1``
        (both labels 0) or ``p == 4`` (e.g. both labels 1).

        Args:
            a_label: (n, 1) anchor labels.
            current_mask: positive-pair mask whose leading (n, n) block
                corresponds to anchor-anchor pairs; modified in place.

        Returns:
            ``current_mask`` (the same tensor object).
        """
        shifted = a_label + 1
        pair_product = torch.matmul(shifted, torch.transpose(shifted, 0, 1))
        # Basic slicing returns a view, so the masked assignment below writes
        # through to ``current_mask``.
        anchor_block = current_mask[: pair_product.shape[1], : pair_product.shape[0]]
        anchor_block[(pair_product == 1.0) | (pair_product == 4.0)] = 0
        return current_mask

    def info_nce(self, anchors_, a_labels_, contras_, c_labels_):
        """Supervised InfoNCE between anchors and (anchors + contrast) samples.

        Args:
            anchors_: (n, d) anchor embeddings.
            a_labels_: (n, 1) anchor labels.
            contras_: (m, d) contrast embeddings.
            c_labels_: (m, 1) contrast labels.

        Returns:
            Scalar tensor: negative mean over anchors of the average
            log-probability of their positive pairs.

        Raises:
            AssertionError: if the per-anchor log-probability contains NaN.
        """
        # Anchors are also contrasted against each other.
        c_labels_ = torch.cat([a_labels_, c_labels_])
        contras_ = torch.cat([anchors_, contras_])

        mask = torch.eq(a_labels_, torch.transpose(c_labels_, 0, 1)).float()
        anchor_dot_contrast = torch.div(
            torch.matmul(anchors_, torch.transpose(contras_, 0, 1)),
            self.temperature,
        )
        # Subtract the row maximum for numerical stability.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Negatives are fixed before the positive mask is edited in place.
        neg_mask = 1 - mask
        mask = self.manipulate_cover_mask(a_label=a_labels_, current_mask=mask)
        mask = mask.fill_diagonal_(0.0)  # an anchor is not its own positive

        exp_logits = torch.exp(logits)
        neg_logits = (exp_logits * neg_mask).sum(1, keepdim=True)
        log_prob = logits - torch.log(exp_logits + neg_logits)

        mask_pos_pairs = mask.sum(1)
        # Anchors without positives contribute 0 instead of dividing by zero.
        mask_pos_pairs = torch.where(
            mask_pos_pairs < 1e-6, torch.ones_like(mask_pos_pairs), mask_pos_pairs
        )
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        if torch.isnan(mean_log_prob_pos).any():
            raise AssertionError(
                "NaN in contrastive loss "
                f"(NaN present in log_prob: {bool(torch.isnan(log_prob).any())})"
            )
        return -mean_log_prob_pos.mean()
|