# DTECT / backend/models/CFDTM/ETC.py
# Uploaded by AdhyaSuman — commit 11c72a2: "Initial commit with Git LFS for large files"
import torch
import torch.nn as nn
import torch.nn.functional as F
class ETC(nn.Module):
    """Contrastive regulariser over per-time-slice topic embeddings.

    The loss has two terms:

    * Negative term: within each time slice, topic embeddings are pushed away
      from each other (``self_contrast`` mode of ``compute_loss``).
    * Positive term: for each consecutive pair of slices, row ``i`` of slice
      ``t`` is pulled toward row ``i`` of slice ``t - 1`` (``only_pos`` mode).
      The previous slice is detached, so gradients only flow into slice ``t``.

    Args:
        num_times: Number of time slices iterated over in ``forward``.
        temperature: Divisor applied to the cosine-similarity matrix.
        weight_neg: Weight of the (time-averaged) negative term.
        weight_pos: Weight of the (pair-averaged) positive term.
    """

    def __init__(self, num_times, temperature, weight_neg, weight_pos):
        super().__init__()
        self.num_times = num_times
        self.weight_neg = weight_neg
        self.weight_pos = weight_pos
        self.temperature = temperature

    def forward(self, topic_embeddings):
        """Return ``weight_neg * mean_t(neg_t) + weight_pos * mean_t(pos_t)``.

        Args:
            topic_embeddings: Indexable by time slice; ``topic_embeddings[t]``
                is a 2-D tensor whose rows are topic vectors (presumably
                ``(K, D)`` with the same ``K`` for every slice — confirm
                against the caller).

        Returns:
            The combined loss. The positive term needs at least two slices
            and is omitted (contributes ``0.``) when ``num_times < 2``. If
            ``num_times`` is not positive, the result is the float ``0.``.
        """
        loss_neg = 0.
        loss_pos = 0.

        if self.num_times > 0:
            for t in range(self.num_times):
                loss_neg += self.compute_loss(
                    topic_embeddings[t], topic_embeddings[t], self.temperature,
                    self_contrast=True,
                )
            loss_neg *= self.weight_neg / self.num_times

        # Guard: with a single slice there are no consecutive pairs, and
        # dividing by (num_times - 1) would raise ZeroDivisionError.
        if self.num_times > 1:
            for t in range(1, self.num_times):
                # Previous slice is detached: it acts as a fixed target.
                loss_pos += self.compute_loss(
                    topic_embeddings[t], topic_embeddings[t - 1].detach(), self.temperature,
                    self_contrast=False, only_pos=True,
                )
            loss_pos *= self.weight_pos / (self.num_times - 1)

        return loss_neg + loss_pos

    def compute_loss(self, anchor_feature, contrast_feature, temperature, self_contrast=False, only_pos=False, all_neg=False):
        """Contrastive loss between the rows of two 2-D feature matrices.

        Rows of both inputs are L2-normalised, and their cosine-similarity
        matrix is divided by ``temperature``. Each row of the resulting logits
        is shifted by its (detached) row maximum for numerical stability.

        Exactly one mode is used, checked in this order:

        * ``self_contrast``: ``sum_i log(sum_{j != i} exp(l_ij) + 1e-12)``
          divided by the number of off-diagonal entries.
        * ``only_pos``: ``-sum_i l_ii`` divided by the number of diagonal
          entries (row ``i`` of the anchor is paired with row ``i`` of the
          contrast).
        * ``all_neg``: ``sum_i log(sum_j exp(l_ij) + 1e-12)`` divided by the
          total number of entries.

        Returns:
            A scalar tensor.

        Raises:
            ValueError: if none of ``self_contrast``, ``only_pos`` or
                ``all_neg`` is set. Previously this path failed with an
                ``UnboundLocalError``.
        """
        anchor_dot_contrast = torch.div(
            torch.matmul(F.normalize(anchor_feature, dim=1), F.normalize(contrast_feature, dim=1).T),
            temperature,
        )
        # Subtracting the detached row max leaves gradients unchanged, but
        # keeps exp() from overflowing.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        num_rows, num_cols = logits.shape
        # Identity mask of the logits' own shape, so anchor and contrast may
        # have different row counts. It is created directly on the right device.
        pos_mask = torch.eye(num_rows, num_cols, device=logits.device)

        if self_contrast:
            # Push apart within the same time slice; the diagonal (self-pairs)
            # is excluded from the denominator.
            neg_mask = 1 - pos_mask
            exp_logits = torch.exp(logits) * neg_mask
            log_prob = -torch.log(exp_logits.sum(1) + 1e-12)
            return -log_prob.sum() / neg_mask.sum()

        if only_pos:
            return -(logits * pos_mask).sum() / pos_mask.sum()

        if all_neg:
            exp_logits = torch.exp(logits)
            log_prob = -torch.log(exp_logits.sum(1) + 1e-12)
            return -log_prob.sum() / (num_rows * num_cols)

        raise ValueError(
            "compute_loss requires one of self_contrast, only_pos or all_neg to be set"
        )