from torch import nn
import torch

from root_gnn_base import utils
|
|
class MaskedLoss():
    """Builds a boolean row mask from cut specs: dicts with 'idx', 'op', 'val'."""

    def __init__(self, mask=None):
        # A None default avoids the shared mutable-default-list pitfall.
        self.mask = mask if mask is not None else []
|
|
    def make_mask(self, targets):
        # Start with every row enabled, then zero out rows matching any cut.
        mask = torch.ones_like(targets[:, 0])
        for m in self.mask:
            if m['op'] == 'eq':
                mask[targets[:, m['idx']] == m['val']] = 0
            elif m['op'] == 'gt':
                mask[targets[:, m['idx']] > m['val']] = 0
            elif m['op'] == 'lt':
                mask[targets[:, m['idx']] < m['val']] = 0
            elif m['op'] == 'ge':
                mask[targets[:, m['idx']] >= m['val']] = 0
            elif m['op'] == 'le':
                mask[targets[:, m['idx']] <= m['val']] = 0
            elif m['op'] == 'ne':
                mask[targets[:, m['idx']] != m['val']] = 0
            else:
                raise ValueError(f'Unknown mask op {m["op"]}')
        # Boolean tensor suitable for fancy indexing.
        return mask == 1
|
|
class MaskedL1Loss(MaskedLoss):
    def __init__(self, mask=None, index=0):
        super().__init__(mask)
        # Column of `targets` that holds the regression label.
        self.index = index
        self.loss = nn.L1Loss()

    def __call__(self, logits, targets):
        mask = self.make_mask(targets)
        # Drop a trailing singleton dimension so an (N, 1) prediction compares
        # against the (N,) target column instead of broadcasting inside L1Loss.
        preds = logits[mask]
        if preds.dim() > 1:
            preds = preds.squeeze(-1)
        return self.loss(preds, targets[mask][:, self.index])
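
# Example usage (illustrative sketch; the two-column target layout, a class
# flag in column 0 and a mass in column 1, is an assumption, not enforced here):
#
#     loss_fn = MaskedL1Loss(mask=[{'op': 'eq', 'idx': 0, 'val': 0}], index=1)
#     loss = loss_fn(model(batch), targets)   # L1 over rows with flag != 0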
|
|
class BCEWithLogitsLoss():
    """Thin wrapper applying BCE-with-logits to the first output column."""

    def __init__(self, weight=None, reduction='mean'):
        self.loss = nn.BCEWithLogitsLoss(weight=weight, reduction=reduction)

    def __call__(self, logits, targets):
        return self.loss(logits[:, 0], targets.float())
|
|
class MultiScore():
    """Applies several score functions to disjoint slices of the last layer."""

    def __init__(self, scores):
        self.score_fcns = []
        self.start_idx = []
        self.end_idx = []
        for score in scores:
            self.score_fcns.append(utils.buildFromConfig(score))
            self.start_idx.append(score['start_idx'])
            self.end_idx.append(score['end_idx'])

    def __call__(self, last_layer):
        scores = []
        for fcn, start, end in zip(self.score_fcns, self.start_idx, self.end_idx):
            scores.append(fcn(last_layer[:, start:end]))
        return torch.cat(scores, dim=1)
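
# Example configuration (sketch; beyond the 'start_idx'/'end_idx' keys read
# above, the schema consumed by utils.buildFromConfig is an assumption):
#
#     scores = MultiScore([
#         {..., 'start_idx': 0, 'end_idx': 1},   # e.g. a sigmoid over column 0
#         {..., 'start_idx': 1, 'end_idx': 4},   # e.g. a softmax over columns 1-3
#     ])
#     out = scores(last_layer)                   # (N, 4) concatenated scores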
|
|
class MultiLoss():
    """Weighted sum of losses, each on its own output/target slice."""

    def __init__(self, losses):
| self.loss_fcns = [] |
| self.label_start_idx = [] |
| self.label_end_idx = [] |
| self.output_start_idx = [] |
| self.output_end_idx = [] |
| self.weights = [] |
| self.label_types = [] |
| for loss in losses: |
| self.loss_fcns.append(utils.buildFromConfig(loss)) |
| self.label_start_idx.append(loss['label_start_idx']) |
| self.label_end_idx.append(loss['label_end_idx']) |
| self.output_start_idx.append(loss['output_start_idx']) |
| self.output_end_idx.append(loss['output_end_idx']) |
| self.weights.append(loss.get('weight', 1.0)) |
| self.label_types.append(loss.get('label_type', 'float')) |
|
|
    def __call__(self, logits, targets):
        loss = 0
        for i in range(len(self.loss_fcns)):
            preds = logits[:, self.output_start_idx[i]:self.output_end_idx[i]]
            if self.label_types[i] == 'int':
                # Integer labels (e.g. class indices): take a single target
                # column and cast it for losses such as CrossEntropyLoss.
                labels = targets[:, self.label_start_idx[i]].to(torch.long)
            elif self.label_end_idx[i] - self.label_start_idx[i] == 1:
                # Single float label: pass the column as a 1-D tensor.
                labels = targets[:, self.label_start_idx[i]]
            else:
                # Multi-column label: pass the slice through unchanged.
                labels = targets[:, self.label_start_idx[i]:self.label_end_idx[i]]
            loss += self.weights[i] * self.loss_fcns[i](preds, labels)
        return loss
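
# Example configuration (sketch; each loss entry is built by
# utils.buildFromConfig, whose exact schema is assumed here):
#
#     loss_fn = MultiLoss([
#         {..., 'output_start_idx': 0, 'output_end_idx': 1,
#          'label_start_idx': 0, 'label_end_idx': 1, 'weight': 1.0},
#         {..., 'output_start_idx': 1, 'output_end_idx': 4,
#          'label_start_idx': 1, 'label_end_idx': 2,
#          'label_type': 'int', 'weight': 0.5},
#     ])
#     loss = loss_fn(logits, targets)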
| |
class AdvLoss():
    """Primary loss minus a weighted adversarial loss.

    The adversarial term is evaluated only on rows with target column 0 == 0
    and is subtracted, so minimising the total degrades the adversary.
    """

    def __init__(self, loss, adv_loss, adv_weight=1.0):
        self.loss_fcn = utils.buildFromConfig(loss)
        self.adv_loss_fcn = utils.buildFromConfig(adv_loss)
        self.adv_weight = adv_weight

    def __call__(self, logits, targets):
        mask = targets[:, 0] == 0
        loss = self.loss_fcn(logits[:, 0], targets[:, 0])
        # The adversary sees output column 1 and the full masked target
        # tensor, so it can pick its own label column (e.g. a MaskedLoss).
        adv_loss = self.adv_loss_fcn(logits[mask][:, 1], targets[mask])
        return loss - self.adv_weight * adv_loss
|
|
class MassWindowAdvLoss(AdvLoss):
    def __call__(self, logits, targets):
        # Restrict the adversary to rows with target column 0 == 0 whose
        # mass (target column 1) falls inside the (5, 25) window.
        mask = (targets[:, 0] == 0) & (targets[:, 1] > 5) & (targets[:, 1] < 25)
        loss = self.loss_fcn(logits[:, 0], targets[:, 0])
        adv_loss = self.adv_loss_fcn(logits[mask][:, 1], targets[mask][:, 1])
        return loss - self.adv_weight * adv_loss
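
# Example usage (sketch; both nested configs go through utils.buildFromConfig
# and their exact contents are assumptions):
#
#     adv = MassWindowAdvLoss(
#         loss={...},       # primary loss on output/target column 0
#         adv_loss={...},   # adversary on the mass in column 1
#         adv_weight=0.1,
#     )
#     total = adv(logits, targets)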
|
|
class KDELoss(MaskedLoss):
    """KDE-based decorrelation penalty.

    Estimates (up to normalisation) the integrated squared difference between
    the joint kernel density of (mass, score) and the product of their
    marginals; it shrinks toward zero as the score becomes independent of
    the mass.
    """

    def __init__(self, mask=None, index=0):
        self.index = index
        super().__init__(mask)

    def __call__(self, logits, targets):
        mask = self.make_mask(targets)
        logits = logits[mask]
        targets = targets[mask][:, self.index]
        N = logits.shape[0]
        # Scale both variables to unit RMS so one bandwidth factor serves both
        # axes. Note the score RMS is taken over all logit entries, which
        # matches the score column only for a single-column output.
        masses = targets / torch.sqrt(torch.mean(targets**2))
        scores = logits[:, 0] / torch.sqrt(torch.mean(logits**2))

        # Squared Silverman-style bandwidth for a 2-D KDE: h^2 ~ N^(-1/3).
        factor_2d = (1.0 * N) ** (-2 / 6)
        covs = (factor_2d * torch.var(masses), factor_2d * torch.var(scores))

        # All pairwise differences along each axis.
        m_diffs = torch.unsqueeze(masses, 1) - torch.unsqueeze(masses, 0)
        s_diffs = torch.unsqueeze(scores, 1) - torch.unsqueeze(scores, 0)

        # Gaussian kernel overlap integrals per axis.
        ymm = torch.exp(-(m_diffs**2) / (4 * covs[0]))
        yss = torch.exp(-(s_diffs**2) / (4 * covs[1]))

        # Contractions for joint*joint, marginals*marginals, joint*marginals.
        integral_rho_2d_rho_2d = torch.einsum('ij,ij->', ymm, yss)
        integral_rho_1d_rho_1d = torch.einsum('ij,kl->', ymm, yss)
        integral_rho_2d_rho_1d = torch.einsum('ij,ik->', ymm, yss)
        raw_integral = (integral_rho_2d_rho_2d
                        - 2 * integral_rho_2d_rho_1d / N
                        + integral_rho_1d_rho_1d / N**2)
        return raw_integral / (4 * torch.pi * N**2)
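
# Example usage (sketch; assumes the mass sits in target column 1):
#
#     kde = KDELoss(index=1)
#     penalty = kde(logits, targets)  # shrinks as score and mass decorrelate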
|
|
class MultiLabelLoss():
    def __init__(self, label_names, label_types, label_weights=None):
        # label_names is accepted for config symmetry but not used here.
        self.loss_fcn = []
        if label_weights:
            self.weights = torch.tensor(label_weights)
        else:
            self.weights = torch.ones(len(label_types))
        for label_type in label_types:
            if label_type == "r":
                # reduction='none' keeps per-sample losses (the deprecated
                # reduce=False spelling did the same for MSE).
                self.loss_fcn.append(torch.nn.MSELoss(reduction='none'))
            elif label_type == "c":
                # Also per-sample here, so classification and regression terms
                # combine element-wise before the final mean.
                self.loss_fcn.append(torch.nn.BCEWithLogitsLoss(reduction='none'))
        print(f"self.weights = {self.weights}")

    def __call__(self, logits, targets):
        targets = targets.float()
        # logits.get_device() returns -1 on CPU, which `device=` rejects;
        # logits.device works everywhere.
        loss = torch.zeros(logits.shape[0], device=logits.device)
        for i in range(len(self.loss_fcn)):
            loss += self.weights[i] * self.loss_fcn[i](logits[:, i], targets[:, i])
        return torch.mean(loss)
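
# Example usage (sketch; 'pt' and 'is_signal' are hypothetical label names):
#
#     loss_fn = MultiLabelLoss(
#         label_names=['pt', 'is_signal'],
#         label_types=['r', 'c'],        # 'r' = regression, 'c' = classification
#         label_weights=[1.0, 2.0],
#     )
#     loss = loss_fn(logits, targets)    # logits and targets shaped (N, 2)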
| |
| |
class MultiLabelFinish():
    """Maps classification logits through a sigmoid at inference time and
    leaves regression outputs untouched."""

    def __init__(self, label_names, label_types):
        self.finish_fcn = []
        for label_type in label_types:
            if label_type == "r":
                self.finish_fcn.append(None)
            elif label_type == "c":
                self.finish_fcn.append(torch.special.expit)

    def __call__(self, logits):
        for i in range(len(self.finish_fcn)):
            if self.finish_fcn[i]:
                # Apply the sigmoid to the float logits directly; casting to
                # torch.long first would truncate them.
                logits[:, i] = self.finish_fcn[i](logits[:, i])
        return logits
|
|
class ContrastiveClusterLoss():
    """NT-Xent-style contrastive term plus a k-means clustering term and a
    variance regulariser.

    Expects each row of `logits` to hold two concatenated embeddings of the
    same sample (original and augmented views).
    """

    def __init__(self, k=10, temperature=1, alpha=1):
        self.k = k
        self.temperature = temperature
        self.alpha = alpha

    def __call__(self, logits, targets):
        targets = targets.float()
        logits_combined = logits.float()

        # Split the concatenated embedding into the two views.
        hid_size = logits_combined.shape[1] // 2
        logits = normalize_embeddings(logits_combined[:, :hid_size])
        logits_augmented = normalize_embeddings(logits_combined[:, hid_size:])

        contrastive = contrastive_loss(logits, logits_augmented, self.temperature)
        clustering, _ = clustering_loss(logits, self.k)
        variance_loss = (variance_regularization(logits)
                         + variance_regularization(logits_augmented))

        return torch.mean(contrastive + clustering + self.alpha * variance_loss)
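
# Example usage (sketch; view_a/view_b are hypothetical embedding views of
# the same samples, concatenated along the feature axis):
#
#     loss_fn = ContrastiveClusterLoss(k=10, temperature=0.5, alpha=1.0)
#     loss = loss_fn(torch.cat([view_a, view_b], dim=1), targets)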
| |
class ContrastiveClusterFinish():
    """Inference-time helper returning the three loss components separately."""

    def __init__(self, k=10, temperature=1, max_cluster_iterations=10):
        self.k = k
        self.temperature = temperature
        self.max_cluster_iterations = max_cluster_iterations
        print(f"ContrastiveClusterFinish: k = {k}, temperature = {temperature}")

    def __call__(self, logits):
        logits_combined = logits.float()

        # Split the concatenated embedding into the two views.
        hid_size = logits_combined.shape[1] // 2
        logits = logits_combined[:, :hid_size]
        logits_augmented = logits_combined[:, hid_size:]

        contrastive = contrastive_loss(logits, logits_augmented, self.temperature)
        clustering, _ = clustering_loss(logits, self.k, self.max_cluster_iterations)
        variance = (variance_regularization(logits)
                    + variance_regularization(logits_augmented))

        return contrastive, clustering, variance
| |
def s(z_i, z_j):
    """Pairwise Euclidean distances between two sets of embeddings."""
    z_i = torch.as_tensor(z_i)
    z_j = torch.as_tensor(z_j)
    return torch.cdist(z_i, z_j, p=2)

|
def contrastive_loss(logits, logits_augmented, temperature=1):
    """NT-Xent-style contrastive loss over the 2N stacked embeddings.

    The positive for row k is its augmented view at row k + N; the
    denominator runs over all other rows, with self-similarity masked out.
    """
    logits = torch.as_tensor(logits)
    logits_augmented = torch.as_tensor(logits_augmented)

    z = torch.cat((logits, logits_augmented), dim=0)
    # Cosine similarity scaled by the temperature.
    similarity_matrix = torch.mm(z, z.t()) / temperature
    norms = torch.linalg.norm(z, dim=1)
    norm_matrix = torch.outer(norms, norms)
    similarity_matrix = similarity_matrix / norm_matrix
    mask = torch.eye(similarity_matrix.size(0), dtype=torch.bool,
                     device=similarity_matrix.device)

    loss = 0
    for k in range(len(logits)):
        numerator = torch.exp(similarity_matrix[k, k + len(logits)])
        denominator = torch.sum(torch.exp(similarity_matrix[k, ~mask[k]]))
        loss += -torch.log(numerator / denominator)

    return loss
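
# Example usage (sketch):
#
#     z1 = normalize_embeddings(torch.randn(32, 16))
#     z2 = normalize_embeddings(torch.randn(32, 16))
#     nt_xent = contrastive_loss(z1, z2, temperature=0.5)  # scalar, summed over rows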
|
|
|
|
def clustering_loss(logits, k=10, max_iterations=10):
    """Lightweight k-means on the embeddings; returns the within-cluster sum
    of squared distances and the final cluster means."""
    # Initialise the means from k randomly chosen points.
    indices = torch.randperm(logits.size(0))[:k]
    cluster_means = logits[indices]

    prev_assignments = None
    assignment_history = []
    iteration = 0

    while iteration < max_iterations:
        iteration += 1

        # Assign every point to its nearest mean.
        distances = torch.cdist(logits, cluster_means, p=2)
        cluster_assignments = torch.argmin(distances, dim=1)

        # Converged: assignments unchanged from the previous iteration.
        if prev_assignments is not None and torch.equal(cluster_assignments, prev_assignments):
            break

        # Cycle detection: assignments seen in any earlier iteration.
        if any(torch.equal(cluster_assignments, prev) for prev in assignment_history):
            break

        assignment_history.append(cluster_assignments.clone())
        prev_assignments = cluster_assignments.clone()

        # Recompute each mean from its assigned points.
        new_cluster_means = torch.zeros_like(cluster_means)
        for i in range(k):
            assigned_points = logits[cluster_assignments == i]
            if assigned_points.size(0) > 0:
                new_cluster_means[i] = assigned_points.mean(dim=0)
            else:
                # Re-seed an empty cluster with a random point.
                new_cluster_means[i] = logits[torch.randint(0, logits.size(0), (1,)).item()]
        cluster_means = new_cluster_means

    # Within-cluster sum of squared distances to the nearest mean.
    distances = torch.cdist(logits, cluster_means, p=2)
    min_distances = torch.min(distances, dim=1)[0]
    loss = torch.sum(min_distances ** 2)

    return loss, cluster_means
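
# Example usage (sketch):
#
#     wss, means = clustering_loss(torch.randn(128, 16), k=8, max_iterations=20)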
|
|
def normalize_embeddings(embeddings):
    """Scale each row to unit L2 norm."""
    return embeddings / embeddings.norm(dim=1, keepdim=True)


def variance_regularization(embeddings):
    """Mean squared deviation of the embeddings from their centroid."""
    mean_embedding = embeddings.mean(dim=0)
    variance = ((embeddings - mean_embedding) ** 2).mean()
    return variance
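

# Minimal smoke test (illustrative; the shapes and the two-column target
# layout are assumptions, and the outputs are only checked by eye, not
# against reference values):
if __name__ == "__main__":
    torch.manual_seed(0)
    N, hid = 64, 8

    # MaskedL1Loss on an (N, 1) output against target column 1,
    # skipping rows whose column 0 equals 0.
    logits = torch.randn(N, 1)
    targets = torch.stack([torch.randint(0, 2, (N,)).float(),
                           torch.randn(N)], dim=1)
    l1 = MaskedL1Loss(mask=[{'op': 'eq', 'idx': 0, 'val': 0}], index=1)
    print("masked L1:", l1(logits, targets).item())

    # KDE decorrelation penalty on the same tensors.
    kde = KDELoss(index=1)
    print("KDE penalty:", kde(logits, targets).item())

    # Contrastive + clustering on two concatenated embedding views.
    embeddings = torch.cat([torch.randn(N, hid), torch.randn(N, hid)], dim=1)
    cc = ContrastiveClusterLoss(k=4, temperature=0.5, alpha=1.0)
    print("contrastive-cluster:", cc(embeddings, targets).item())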
|
|
|
|