import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist


def contrastive_loss(v1, v2, tau=1.0):
    """InfoNCE-style contrastive loss between two batches of embeddings.

    Args:
        v1, v2: (B, D) embeddings; row i of v1 and row i of v2 form the positive pair.
        tau: temperature.

    Returns:
        (loss, mean positive score, mean denominator) -- the last two are for logging.
    """
    v1_norm = torch.norm(v1, dim=1, keepdim=True)
    v2_norm = torch.norm(v2, dim=1, keepdim=True)

    v2T = torch.transpose(v2, 0, 1)
    inner_prod = torch.matmul(v1, v2T)

    v2_normT = torch.transpose(v2_norm, 0, 1)
    norm_mat = torch.matmul(v1_norm, v2_normT).clamp(min=1e-8)  # guard against zero norms

    loss_mat = torch.div(inner_prod, norm_mat)  # cosine similarity matrix
    loss_mat = loss_mat / tau
    loss_mat = torch.exp(loss_mat)

    numerator = torch.diagonal(loss_mat)
    numerator = torch.unsqueeze(numerator, 0)

    Lv1_v2_denom = torch.sum(loss_mat, dim=1, keepdim=True)
    Lv1_v2_denom = torch.transpose(Lv1_v2_denom, 0, 1)
    # Lv1_v2_denom = Lv1_v2_denom - numerator
    Lv2_v1_denom = torch.sum(loss_mat, dim=0, keepdim=True)
    # Lv2_v1_denom = Lv2_v1_denom - numerator

    Lv1_v2 = torch.div(numerator, Lv1_v2_denom)
    Lv1_v2 = -1 * torch.log(Lv1_v2)
    Lv1_v2 = torch.mean(Lv1_v2)

    Lv2_v1 = torch.div(numerator, Lv2_v1_denom)
    Lv2_v1 = -1 * torch.log(Lv2_v1)
    Lv2_v1 = torch.mean(Lv2_v1)

    return Lv1_v2 + Lv2_v1, torch.mean(numerator), torch.mean(Lv1_v2_denom + Lv2_v1_denom)


def cand_spec_sim_loss(spec_enc, cand_enc):
    """Mean cosine similarity between spectrum encodings and their candidate encodings.

    Args:
        spec_enc: (B, d) spectrum embeddings.
        cand_enc: (B, C, d) candidate embeddings per spectrum.
    """
    cand_enc = torch.transpose(cand_enc, 0, 1)  # C x B x d
    spec_enc = spec_enc.unsqueeze(0)            # 1 x B x d
    sim = nn.functional.cosine_similarity(spec_enc, cand_enc, dim=2)
    loss = torch.mean(sim)
    return loss


class cons_spec_loss:
    """Consensus-spectrum vs. individual-spectrum loss (cosine or L2)."""

    def __init__(self, loss_type) -> None:
        self.loss_compute = {'cosine': self.cos_loss, 'l2': torch.nn.MSELoss()}[loss_type]

    def __call__(self, cons_spec, ind_spec):
        return self.loss_compute(cons_spec, ind_spec)

    def cos_loss(self, cons_spec, ind_spec):
        sim = nn.functional.cosine_similarity(cons_spec, ind_spec)
        loss = 1 - torch.mean(sim)
        return loss


class fp_loss:
    """Fingerprint prediction loss (cosine or BCE).

    Note: the 'bce' option uses nn.BCELoss, which expects predicted_fp to already
    contain probabilities in [0, 1] (e.g. after a sigmoid).
    """

    def __init__(self, loss_type) -> None:
        self.loss_compute = {'cosine': self.fp_loss_cos, 'bce': nn.BCELoss()}[loss_type]

    def __call__(self, predicted_fp, target_fp):
        return self.loss_compute(predicted_fp, target_fp)

    def fp_loss_cos(self, predicted_fp, target_fp):
        sim = nn.functional.cosine_similarity(predicted_fp, target_fp)
        return 1 - torch.mean(sim)


# ---------- Utility ----------
def _safe_divide(num, denom, eps=1e-8):
    return num / (denom + eps)


# ---------- Single-GPU masked FILIP ----------
def filip_loss_with_mask(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07):
    """
    Single-GPU FILIP loss for modality A (spectra peaks) and modality B (graph nodes),
    accounting for padding masks.

    Args:
        a_tokens: (B, N_a, D) float tensor (will be normalized to unit vectors)
        b_tokens: (B, N_b, D)
        mask_a: (B, N_a) bool or byte tensor (True = valid)
        mask_b: (B, N_b) bool or byte tensor
        temperature: scalar or 0-dim tensor (learnable ok)
    Returns:
        scalar loss
    """
    device = a_tokens.device
    B, N_a, D = a_tokens.shape
    N_b = b_tokens.shape[1]

    # Normalize so that dot products are cosine similarities.
    a = F.normalize(a_tokens, dim=-1)
    b = F.normalize(b_tokens, dim=-1)

    # All pairwise token similarities across the batch:
    # sim[i, j, k, l] = dot(a[i, k], b[j, l]), shape (B, B, N_a, N_b).
    sim = torch.einsum('ikd,jld->ijkl', a, b)

    # Expand masks to (B, B, N_a) and (B, B, N_b).
    mask_a_exp = mask_a.unsqueeze(1).expand(-1, B, -1)  # (B, B, N_a)
    mask_b_exp = mask_b.unsqueeze(0).expand(B, -1, -1)  # (B, B, N_b)

    # ---- A -> B similarity (s_a2b) ----
    # For every a-token take the max over valid b-tokens;
    # set invalid positions in sim to -inf before the max.
    sim_a2b = sim.clone()
    invalid_b = ~mask_b_exp.unsqueeze(2).expand(-1, -1, sim_a2b.size(2), -1)  # (B, B, N_a, N_b)
    sim_a2b[invalid_b] = float('-inf')
    max_over_b = sim_a2b.max(dim=3).values  # (B, B, N_a)

    # Zero out padded a-tokens, then average over valid tokens only.
    max_over_b = max_over_b * mask_a_exp
    denom_a = mask_a_exp.sum(dim=2).clamp(min=1).to(sim.dtype)  # (B, B)
    s_a2b = max_over_b.sum(dim=2) / denom_a  # (B, B)

    # ---- B -> A similarity (s_b2a) ----
    sim_b2a = sim.clone()
    invalid_a = ~mask_a_exp.unsqueeze(3).expand(-1, -1, -1, sim_b2a.size(3))  # (B, B, N_a, N_b)
    sim_b2a[invalid_a] = float('-inf')
    max_over_a = sim_b2a.max(dim=2).values  # (B, B, N_b)
    max_over_a = max_over_a * mask_b_exp
    denom_b = mask_b_exp.sum(dim=2).clamp(min=1).to(sim.dtype)
    s_b2a = max_over_a.sum(dim=2) / denom_b  # (B, B); rows index A samples, columns index B samples

    # Logits and symmetric cross-entropy; positives are on the diagonal.
    logits_a2b = s_a2b / temperature
    logits_b2a = s_b2a / temperature
    labels = torch.arange(B, device=device, dtype=torch.long)
    loss_a2b = F.cross_entropy(logits_a2b, labels)
    # Transpose so each B sample is scored against all A candidates (B -> A direction).
    loss_b2a = F.cross_entropy(logits_b2a.t(), labels)
    return 0.5 * (loss_a2b + loss_b2a)
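
# Illustrative sanity check for `filip_loss_with_mask` (a minimal sketch, not part of
# the training code): builds random "spectrum peak" and "graph node" token batches with
# ragged lengths, pads them via boolean masks, and checks that the loss is a finite
# scalar. The shapes, lengths, and the 0.07 temperature below are assumptions.
def _demo_filip_loss_with_mask():
    torch.manual_seed(0)
    B, N_a, N_b, D = 4, 6, 5, 8
    a_tokens = torch.randn(B, N_a, D)
    b_tokens = torch.randn(B, N_b, D)
    # Valid-token masks: the first `len` positions of each sample are real tokens.
    len_a = torch.tensor([6, 4, 3, 5])
    len_b = torch.tensor([5, 5, 2, 4])
    mask_a = torch.arange(N_a)[None, :] < len_a[:, None]
    mask_b = torch.arange(N_b)[None, :] < len_b[:, None]
    loss = filip_loss_with_mask(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07)
    assert loss.dim() == 0 and torch.isfinite(loss)
    return loss
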
def global_infonce_loss(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07, agg_fn="mean"):
    """
    Global InfoNCE loss (CLIP-style) for modalities A and B.

    Args:
        a_tokens: (B, N_a, D)
        b_tokens: (B, N_b, D)
        mask_a: (B, N_a) bool (True = valid)
        mask_b: (B, N_b) bool (True = valid)
        temperature: scalar
        agg_fn: "mean" | "max" | "cls" | callable -> how to aggregate tokens into one vector
    Returns:
        scalar loss
    """
    device = a_tokens.device
    B = a_tokens.shape[0]

    # ---- Normalize token embeddings ----
    a = F.normalize(a_tokens, dim=-1)
    b = F.normalize(b_tokens, dim=-1)

    # ---- Aggregate tokens into one global vector per sample ----
    if callable(agg_fn):
        # Custom aggregation: agg_fn(tokens, mask) -> (B, D)
        a_global = agg_fn(a, mask_a)
        b_global = agg_fn(b, mask_b)
    elif agg_fn == "mean":
        # Masked mean over valid tokens.
        a_global = (a * mask_a.unsqueeze(-1)).sum(dim=1) / mask_a.sum(dim=1, keepdim=True).clamp(min=1)
        b_global = (b * mask_b.unsqueeze(-1)).sum(dim=1) / mask_b.sum(dim=1, keepdim=True).clamp(min=1)
    elif agg_fn == "max":
        # Masked max: padded positions are excluded via -inf.
        a_global = a.masked_fill(~mask_a.unsqueeze(-1), float('-inf')).max(dim=1).values
        b_global = b.masked_fill(~mask_b.unsqueeze(-1), float('-inf')).max(dim=1).values
    elif agg_fn == "cls":
        # Use the first token as a [CLS]-style summary (assumes position 0 is always valid).
        a_global = a[:, 0, :]
        b_global = b[:, 0, :]
    else:
        raise ValueError(f"Unknown agg_fn: {agg_fn}")

    # ---- Cosine similarity matrix between the aggregated vectors ----
    a_global = F.normalize(a_global, dim=-1)
    b_global = F.normalize(b_global, dim=-1)
    logits = (a_global @ b_global.T) / temperature  # (B, B)

    # ---- Symmetric InfoNCE loss with diagonal positives ----
    labels = torch.arange(B, device=device)
    loss_a2b = F.cross_entropy(logits, labels)
    loss_b2a = F.cross_entropy(logits.T, labels)
    loss = 0.5 * (loss_a2b + loss_b2a)
    return loss
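
# Illustrative sketch of calling `global_infonce_loss`, including the callable
# `agg_fn` path. Shapes are assumptions, and `_masked_sum` is a hypothetical
# aggregation helper written here for the demo, not part of the original API.
def _masked_sum(tokens, mask):
    # tokens: (B, N, D), mask: (B, N) bool -> (B, D) sum over valid tokens
    return (tokens * mask.unsqueeze(-1)).sum(dim=1)


def _demo_global_infonce_loss():
    torch.manual_seed(0)
    B, N_a, N_b, D = 4, 6, 5, 8
    a_tokens, b_tokens = torch.randn(B, N_a, D), torch.randn(B, N_b, D)
    mask_a = torch.ones(B, N_a, dtype=torch.bool)
    mask_b = torch.ones(B, N_b, dtype=torch.bool)
    loss_mean = global_infonce_loss(a_tokens, b_tokens, mask_a, mask_b, agg_fn="mean")
    loss_custom = global_infonce_loss(a_tokens, b_tokens, mask_a, mask_b, agg_fn=_masked_sum)
    return loss_mean, loss_custom
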
# ---------- PCGrad utility ----------
def pcgrad_combine(losses, shared_params):
    """
    Compute a PCGrad-combined gradient for a list of scalar losses and write it
    into the .grad fields of the shared parameters.

    Args:
        losses: list of scalar loss tensors
        shared_params: list of parameters (all requiring grad) to project/aggregate gradients for
    Returns:
        scalar combined loss (mean of the input losses), for logging only
    """
    grads_list = []
    for loss in losses:
        grads = torch.autograd.grad(loss, shared_params, retain_graph=True, allow_unused=True)
        # Replace None grads (params unused by this loss) with zeros so every
        # flattened gradient vector has the same layout.
        grads = [g if g is not None else torch.zeros_like(p) for g, p in zip(grads, shared_params)]
        grads_list.append(torch.cat([g.reshape(-1) for g in grads]))

    # Project away conflicting components: if g_i conflicts with g_j (negative
    # dot product), remove from g_i its projection onto the original g_j.
    projected = [g.clone() for g in grads_list]
    for i in range(len(grads_list)):
        for j in range(len(grads_list)):
            if i == j:
                continue
            dot = (projected[i] * grads_list[j]).sum()
            if dot < 0:
                projected[i] = projected[i] - dot / (grads_list[j].norm() ** 2 + 1e-12) * grads_list[j]

    # Sum the projected gradients and write them back into the parameters.
    final_grad = sum(projected)
    pointer = 0
    for p in shared_params:
        numel = p.numel()
        if p.requires_grad:
            p.grad = final_grad[pointer:pointer + numel].view_as(p).clone()
        pointer += numel

    # Return the average loss for logging only.
    return sum(losses) / len(losses)
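
# Minimal sketch of how `pcgrad_combine` might sit inside a training step,
# assuming two task losses that share an encoder. `encoder`, `optimizer`, and
# the toy objectives are hypothetical placeholders, not names from this codebase.
def _demo_pcgrad_step():
    torch.manual_seed(0)
    encoder = nn.Linear(16, 8)
    optimizer = torch.optim.SGD(encoder.parameters(), lr=1e-2)

    x = torch.randn(4, 16)
    z = encoder(x)
    # Two toy objectives that may produce conflicting gradients.
    loss_a = z.pow(2).mean()
    loss_b = -z.mean()

    optimizer.zero_grad()
    shared_params = [p for p in encoder.parameters() if p.requires_grad]
    log_loss = pcgrad_combine([loss_a, loss_b], shared_params)
    # pcgrad_combine has already written the projected gradients into .grad,
    # so the optimizer step is all that remains.
    optimizer.step()
    return log_loss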