import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
def contrastive_loss(v1, v2, tau=1.0):
    """Symmetric InfoNCE loss over the cosine-similarity matrix of two paired
    embedding batches v1, v2 of shape (B, d); row i of v1 matches row i of v2.

    Returns (loss, mean positive term, mean denominator); the last two are
    returned for logging only.
    """
    # Pairwise cosine similarities, scaled by 1/tau and exponentiated.
    v1_norm = torch.norm(v1, dim=1, keepdim=True)
    v2_norm = torch.norm(v2, dim=1, keepdim=True)
    inner_prod = torch.matmul(v1, torch.transpose(v2, 0, 1))
    norm_mat = torch.matmul(v1_norm, torch.transpose(v2_norm, 0, 1))
    loss_mat = torch.exp(inner_prod / (norm_mat * tau))
    # Positive pairs sit on the diagonal.
    numerator = torch.diagonal(loss_mat).unsqueeze(0)  # (1, B)
    # Row sums contrast each v1_i against all v2_j; column sums the reverse.
    Lv1_v2_denom = torch.transpose(torch.sum(loss_mat, dim=1, keepdim=True), 0, 1)
    #Lv1_v2_denom = Lv1_v2_denom - numerator
    Lv2_v1_denom = torch.sum(loss_mat, dim=0, keepdim=True)
    #Lv2_v1_denom = Lv2_v1_denom - numerator
    Lv1_v2 = torch.mean(-torch.log(numerator / Lv1_v2_denom))
    Lv2_v1 = torch.mean(-torch.log(numerator / Lv2_v1_denom))
    return Lv1_v2 + Lv2_v1, torch.mean(numerator), torch.mean(Lv1_v2_denom + Lv2_v1_denom)
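# Example (a minimal sketch; batch size and dimensions are illustrative, not
# from the real pipeline):
#   v1 = torch.randn(4, 128)                      # e.g. spectrum embeddings
#   v2 = torch.randn(4, 128)                      # e.g. molecule embeddings
#   loss, pos, denom = contrastive_loss(v1, v2, tau=0.5)
#   loss.backward()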
def cand_spec_sim_loss(spec_enc, cand_enc):
    """Mean cosine similarity between each spectrum encoding (B x d) and each
    of its candidate encodings (B x C x d). Minimizing this value drives the
    candidate and spectrum encodings apart."""
    cand_enc = torch.transpose(cand_enc, 0, 1)  # C x B x d
    spec_enc = spec_enc.unsqueeze(0)            # 1 x B x d, broadcasts over C
    sim = nn.functional.cosine_similarity(spec_enc, cand_enc, dim=2)  # C x B
    loss = torch.mean(sim)
    return loss
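# Example (illustrative shapes: 4 spectra, 10 candidates each, 128-dim):
#   spec = torch.randn(4, 128)
#   cands = torch.randn(4, 10, 128)
#   sim_loss = cand_spec_sim_loss(spec, cands)   # scalar mean similarity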
class cons_spec_loss:
    """Consensus-vs-individual spectrum loss: cosine distance or MSE,
    selected by `loss_type` ('cosine' or 'l2')."""
    def __init__(self, loss_type) -> None:
        self.loss_compute = {'cosine': self.cos_loss,
                             'l2': torch.nn.MSELoss()}[loss_type]

    def __call__(self, cons_spec, ind_spec):
        return self.loss_compute(cons_spec, ind_spec)

    def cos_loss(self, cons_spec, ind_spec):
        # 1 - mean cosine similarity, so identical inputs give zero loss.
        sim = nn.functional.cosine_similarity(cons_spec, ind_spec)
        loss = 1 - torch.mean(sim)
        return loss
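# Example (illustrative):
#   crit = cons_spec_loss('cosine')
#   loss = crit(torch.randn(4, 128), torch.randn(4, 128))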
class fp_loss:
    """Fingerprint-prediction loss: cosine distance or BCE, selected by
    `loss_type`. Note that nn.BCELoss expects probabilities in [0, 1], so the
    'bce' option assumes the predictions are already sigmoid-activated."""
    def __init__(self, loss_type) -> None:
        self.loss_compute = {'cosine': self.fp_loss_cos,
                             'bce': nn.BCELoss()}[loss_type]

    def __call__(self, predicted_fp, target_fp):
        return self.loss_compute(predicted_fp, target_fp)

    def fp_loss_cos(self, predicted_fp, target_fp):
        sim = nn.functional.cosine_similarity(predicted_fp, target_fp)
        return 1 - torch.mean(sim)
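# Example (illustrative; the 2048-bit fingerprint length is an assumption):
#   crit = fp_loss('bce')
#   pred = torch.sigmoid(torch.randn(4, 2048))   # probabilities, as BCE needs
#   target = torch.randint(0, 2, (4, 2048)).float()
#   loss = crit(pred, target)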
# ---------- Utility ----------
def _safe_divide(num, denom, eps=1e-8):
    return num / (denom + eps)
# ---------- Single-GPU masked FILIP ----------
def filip_loss_with_mask(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07):
    """
    Single-GPU FILIP loss for modality A (spectra peaks) and modality B (graph nodes),
    accounting for padding masks. Assumes every sample has at least one valid
    token per modality; otherwise -inf survives the max and the loss is non-finite.
    Args:
        a_tokens: (B, N_a, D) float tensor (will be normalized to unit vectors)
        b_tokens: (B, N_b, D)
        mask_a: (B, N_a) bool or byte tensor (True=valid)
        mask_b: (B, N_b) bool or byte tensor
        temperature: scalar or 0-dim tensor (learnable ok)
    Returns:
        scalar loss
    """
    device = a_tokens.device
    B = a_tokens.shape[0]
    # Byte masks would break the `~` inversions below, so force bool.
    mask_a = mask_a.bool()
    mask_b = mask_b.bool()
    # Normalize so dot products are cosine similarities.
    a = F.normalize(a_tokens, dim=-1)
    b = F.normalize(b_tokens, dim=-1)
    # All pairwise (batch x batch) token similarities:
    # sim[i, j, k, l] = dot(a[i, k], b[j, l]) -> (B, B, N_a, N_b)
    sim = torch.einsum('ikd,jld->ijkl', a, b)
    # Expand masks to (B, B, N_a) and (B, B, N_b)
    mask_a_exp = mask_a.unsqueeze(1).expand(-1, B, -1)  # (B, B, N_a)
    mask_b_exp = mask_b.unsqueeze(0).expand(B, -1, -1)  # (B, B, N_b)
    # ---- A -> B similarity (s_a2b) ----
    # For every a-token take the max over valid b-tokens; set invalid
    # positions in sim to -inf before the max.
    sim_a2b = sim.clone()
    invalid_b = ~mask_b_exp.unsqueeze(2).expand(-1, -1, sim_a2b.size(2), -1)  # (B, B, N_a, N_b)
    sim_a2b[invalid_b] = float('-inf')
    max_over_b = sim_a2b.max(dim=3).values  # (B, B, N_a)
    # Zero out padded a-tokens, then average over valid tokens only.
    max_over_b = max_over_b * mask_a_exp
    denom_a = mask_a_exp.sum(dim=2).clamp(min=1).to(sim.dtype)  # (B, B)
    s_a2b = max_over_b.sum(dim=2) / denom_a  # (B, B)
    # ---- B -> A similarity (s_b2a) ----
    sim_b2a = sim.clone()
    invalid_a = ~mask_a_exp.unsqueeze(3).expand(-1, -1, -1, sim_b2a.size(3))  # (B, B, N_a, N_b)
    sim_b2a[invalid_a] = float('-inf')
    max_over_a = sim_b2a.max(dim=2).values  # (B, B, N_b)
    max_over_a = max_over_a * mask_b_exp
    denom_b = mask_b_exp.sum(dim=2).clamp(min=1).to(sim.dtype)
    s_b2a = max_over_a.sum(dim=2) / denom_b  # (B, B)
    # Symmetric InfoNCE over the (B, B) similarity matrices.
    logits_a2b = s_a2b / temperature
    logits_b2a = s_b2a / temperature
    labels = torch.arange(B, device=device, dtype=torch.long)
    loss_a2b = F.cross_entropy(logits_a2b, labels)
    loss_b2a = F.cross_entropy(logits_b2a, labels)
    return 0.5 * (loss_a2b + loss_b2a)
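# Example (illustrative: batch of 3, up to 5 peaks / 7 nodes, 64-dim tokens):
#   a = torch.randn(3, 5, 64)
#   b = torch.randn(3, 7, 64)
#   ma = torch.tensor([[1]*3 + [0]*2, [1]*5, [1]*4 + [0]*1]).bool()
#   mb = torch.ones(3, 7).bool()
#   loss = filip_loss_with_mask(a, b, ma, mb, temperature=0.07)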
def global_infonce_loss(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07, agg_fn="mean"):
    """
    Global InfoNCE loss (CLIP-style) for modalities A and B: aggregate each
    sample's tokens into a single vector, then contrast across the batch.
    Args:
        a_tokens: (B, N_a, D)
        b_tokens: (B, N_b, D)
        mask_a: (B, N_a) bool (True = valid)
        mask_b: (B, N_b) bool (True = valid)
        temperature: scalar
        agg_fn: "mean" | "max" | "cls" | callable -> how to aggregate tokens into one vector
    Returns:
        scalar loss
    """
    device = a_tokens.device
    B = a_tokens.shape[0]
    # ---- Normalize token embeddings ----
    a = F.normalize(a_tokens, dim=-1)
    b = F.normalize(b_tokens, dim=-1)
    # ---- Aggregate per sample ----
    if callable(agg_fn):
        a_global = agg_fn(a, mask_a)  # custom aggregation
        b_global = agg_fn(b, mask_b)
    elif agg_fn == "mean":
        # Masked mean over valid tokens.
        a_global = (a * mask_a.unsqueeze(-1)).sum(dim=1) / mask_a.sum(dim=1, keepdim=True).clamp(min=1)
        b_global = (b * mask_b.unsqueeze(-1)).sum(dim=1) / mask_b.sum(dim=1, keepdim=True).clamp(min=1)
    elif agg_fn == "max":
        # Masked max; assumes at least one valid token per sample.
        a_global = a.masked_fill(~mask_a.unsqueeze(-1), float('-inf')).max(dim=1).values
        b_global = b.masked_fill(~mask_b.unsqueeze(-1), float('-inf')).max(dim=1).values
    elif agg_fn == "cls":
        # Use the first token as a "cls" summary; assumes position 0 is always valid.
        a_global = a[:, 0, :]
        b_global = b[:, 0, :]
    else:
        raise ValueError(f"Unknown agg_fn: {agg_fn}")
    # ---- Cosine-similarity matrix between the aggregated vectors ----
    a_global = F.normalize(a_global, dim=-1)
    b_global = F.normalize(b_global, dim=-1)
    logits = (a_global @ b_global.T) / temperature  # (B, B)
    # ---- Symmetric InfoNCE loss ----
    labels = torch.arange(B, device=device)
    loss_a2b = F.cross_entropy(logits, labels)
    loss_b2a = F.cross_entropy(logits.T, labels)
    return 0.5 * (loss_a2b + loss_b2a)
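# Example with a custom aggregator (a hypothetical masked sum, only to show
# the callable interface; reuses a, b, ma, mb from the sketch above):
#   def masked_sum(tokens, mask):
#       return (tokens * mask.unsqueeze(-1)).sum(dim=1)
#   loss = global_infonce_loss(a, b, ma, mb, agg_fn=masked_sum)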
# ---------- PCGrad utility ----------
def pcgrad_combine(losses, shared_params):
    """
    Combine the gradients of several scalar losses with PCGrad (Yu et al., 2020):
    when two task gradients conflict (negative dot product), project one onto the
    normal plane of the other. Writes the combined gradient into `p.grad` for each
    shared parameter.
    losses: list of scalar loss tensors
    shared_params: list of parameters (all requiring grad) to project/aggregate gradients for
    returns: scalar mean loss, for logging only
    """
    grads_list = []
    for l in losses:
        grads = torch.autograd.grad(l, shared_params, retain_graph=True, allow_unused=True)
        # Replace None (unused-parameter) grads with zeros so every task's
        # flattened gradient has the same length and layout.
        grads = [g if g is not None else torch.zeros_like(p)
                 for g, p in zip(grads, shared_params)]
        grads_list.append(torch.cat([g.reshape(-1) for g in grads]))
    projected = [fg.clone() for fg in grads_list]
    # Project away the component that conflicts with each other task's
    # original gradient. (The paper additionally visits tasks in random order.)
    for i in range(len(grads_list)):
        for j in range(len(grads_list)):
            if i == j:
                continue
            dot = (projected[i] * grads_list[j]).sum()
            if dot < 0:
                projected[i] = projected[i] - dot / (grads_list[j].norm() ** 2 + 1e-12) * grads_list[j]
    # Sum the projected gradients and scatter them back into p.grad.
    final_grad = sum(projected)
    pointer = 0
    for p in shared_params:
        numel = p.numel()
        p.grad = final_grad[pointer:pointer + numel].view_as(p).clone()
        pointer += numel
    # Average loss, detached, for logging only.
    return (sum(losses) / len(losses)).detach()
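# Example training step (a sketch; `model`, `opt`, `loss_fn_a`, and `loss_fn_b`
# are placeholders, not names from this codebase):
#   opt.zero_grad()
#   l1 = loss_fn_a(model(x))
#   l2 = loss_fn_b(model(x))
#   avg = pcgrad_combine([l1, l2], list(model.parameters()))
#   opt.step()   # consumes the .grad fields set by pcgrad_combine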