# flare/utils/loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
def contrastive_loss(v1, v2, tau=1.0):
    """Symmetric InfoNCE-style contrastive loss over cosine similarities.

    Args:
        v1, v2: (B, D) embeddings; matched rows are positive pairs.
        tau: softmax temperature.

    Returns:
        (loss, mean positive term, mean denominator) -- the last two values
        are returned for logging.
    """
    v1_norm = torch.norm(v1, dim=1, keepdim=True)               # (B, 1)
    v2_norm = torch.norm(v2, dim=1, keepdim=True)               # (B, 1)
    inner_prod = torch.matmul(v1, v2.transpose(0, 1))           # (B, B)
    norm_mat = torch.matmul(v1_norm, v2_norm.transpose(0, 1))   # (B, B)
    loss_mat = torch.exp(inner_prod / norm_mat / tau)           # exp(cos_sim / tau)
    numerator = torch.diagonal(loss_mat).unsqueeze(0)           # positives, (1, B)
    Lv1_v2_denom = torch.sum(loss_mat, dim=1, keepdim=True).transpose(0, 1)
    #Lv1_v2_denom = Lv1_v2_denom - numerator
    Lv2_v1_denom = torch.sum(loss_mat, dim=0, keepdim=True)
    #Lv2_v1_denom = Lv2_v1_denom - numerator
    Lv1_v2 = torch.mean(-torch.log(numerator / Lv1_v2_denom))
    Lv2_v1 = torch.mean(-torch.log(numerator / Lv2_v1_denom))
    return Lv1_v2 + Lv2_v1, torch.mean(numerator), torch.mean(Lv1_v2_denom + Lv2_v1_denom)
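
# Usage sketch for `contrastive_loss` (illustrative only; the tensor names are
# hypothetical): matched rows of the two views form the positive pairs.
#
#   z_spec  = torch.randn(32, 256)   # view 1, (B, D)
#   z_graph = torch.randn(32, 256)   # view 2, (B, D)
#   loss, pos_sim, denom_mass = contrastive_loss(z_spec, z_graph, tau=0.5)
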
def cand_spec_sim_loss(spec_enc, cand_enc):
    """Mean cosine similarity between each spectrum embedding and its candidates.

    Args:
        spec_enc: (B, d) spectrum embeddings.
        cand_enc: (B, C, d) candidate embeddings per spectrum.
    """
    cand_enc = torch.transpose(cand_enc, 0, 1)  # (C, B, d)
    spec_enc = spec_enc.unsqueeze(0)            # (1, B, d)
    sim = nn.functional.cosine_similarity(spec_enc, cand_enc, dim=2)  # (C, B)
    loss = torch.mean(sim)
    return loss
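
# Shape sketch for `cand_spec_sim_loss` (hypothetical sizes): B = 4 spectra,
# C = 10 candidates each, d = 256.
#
#   spec_enc = torch.randn(4, 256)       # (B, d)
#   cand_enc = torch.randn(4, 10, 256)   # (B, C, d)
#   sim = cand_spec_sim_loss(spec_enc, cand_enc)  # scalar mean cosine similarity
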
class cons_spec_loss:
    """Consensus- vs. individual-spectrum loss; `loss_type` is 'cosine' or 'l2'."""

    def __init__(self, loss_type) -> None:
        self.loss_compute = {'cosine': self.cos_loss,
                             'l2': torch.nn.MSELoss()}[loss_type]

    def __call__(self, cons_spec, ind_spec):
        return self.loss_compute(cons_spec, ind_spec)

    def cos_loss(self, cons_spec, ind_spec):
        sim = nn.functional.cosine_similarity(cons_spec, ind_spec)
        loss = 1 - torch.mean(sim)
        return loss
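
# Usage sketch (hypothetical tensors): compare a consensus-spectrum embedding
# against the embedding of an individual replicate spectrum.
#
#   criterion = cons_spec_loss('cosine')   # or 'l2'
#   loss = criterion(torch.randn(8, 256), torch.randn(8, 256))
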
class fp_loss:
    """Fingerprint-prediction loss; `loss_type` is 'cosine' or 'bce'.

    Note: nn.BCELoss expects `predicted_fp` to already lie in [0, 1]
    (e.g. sigmoid outputs).
    """

    def __init__(self, loss_type) -> None:
        self.loss_compute = {'cosine': self.fp_loss_cos,
                             'bce': nn.BCELoss()}[loss_type]

    def __call__(self, predicted_fp, target_fp):
        return self.loss_compute(predicted_fp, target_fp)

    def fp_loss_cos(self, predicted_fp, target_fp):
        sim = nn.functional.cosine_similarity(predicted_fp, target_fp)
        return 1 - torch.mean(sim)
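
# Usage sketch (hypothetical tensors). With 'bce', predictions must already be
# probabilities in [0, 1], e.g. sigmoid outputs:
#
#   criterion = fp_loss('bce')
#   pred_fp = torch.sigmoid(torch.randn(8, 2048))        # predicted bits
#   true_fp = torch.randint(0, 2, (8, 2048)).float()     # target fingerprint
#   loss = criterion(pred_fp, true_fp)
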
# ---------- Utility ----------
def _safe_divide(num, denom, eps=1e-8):
    return num / (denom + eps)
# ---------- Single-GPU masked FILIP ----------
def filip_loss_with_mask(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07):
    """
    Single-GPU FILIP loss for modality A (spectrum peaks) and modality B (graph nodes),
    accounting for padding masks.

    Args:
        a_tokens: (B, N_a, D) float tensor (normalized to unit vectors internally)
        b_tokens: (B, N_b, D)
        mask_a: (B, N_a) bool or byte tensor (True = valid)
        mask_b: (B, N_b) bool or byte tensor
        temperature: scalar or 0-dim tensor (learnable ok)

    Returns:
        scalar loss
    """
    device = a_tokens.device
    B = a_tokens.shape[0]
    mask_a = mask_a.bool()
    mask_b = mask_b.bool()

    # Normalize so that dot products are cosine similarities.
    a = F.normalize(a_tokens, dim=-1)
    b = F.normalize(b_tokens, dim=-1)

    # All pairwise (batch x batch) token similarities:
    # sim[i, j, k, l] = dot(a[i, k], b[j, l]) -> (B, B, N_a, N_b)
    sim = torch.einsum('ikd,jld->ijkl', a, b)

    # Broadcastable masks over the (B, B, N_a, N_b) grid.
    mask_a_exp = mask_a.unsqueeze(1).expand(-1, B, -1)  # (B, B, N_a)
    mask_b_exp = mask_b.unsqueeze(0).expand(B, -1, -1)  # (B, B, N_b)

    # ---- A -> B similarity (s_a2b) ----
    # For every a-token, take the max over *valid* b-tokens: set invalid
    # positions to -inf before the max.
    sim_a2b = sim.masked_fill(~mask_b_exp.unsqueeze(2), float('-inf'))
    max_over_b = sim_a2b.max(dim=3).values  # (B, B, N_a)
    # Zero out padded a-tokens (masked_fill rather than multiplying by the
    # mask, so -inf entries never become NaN), then average over valid tokens.
    max_over_b = max_over_b.masked_fill(~mask_a_exp, 0.0)
    denom_a = mask_a_exp.sum(dim=2).clamp(min=1).to(sim.dtype)  # (B, B)
    s_a2b = max_over_b.sum(dim=2) / denom_a  # (B, B)

    # ---- B -> A similarity (s_b2a) ----
    sim_b2a = sim.masked_fill(~mask_a_exp.unsqueeze(3), float('-inf'))
    max_over_a = sim_b2a.max(dim=2).values  # (B, B, N_b)
    max_over_a = max_over_a.masked_fill(~mask_b_exp, 0.0)
    denom_b = mask_b_exp.sum(dim=2).clamp(min=1).to(sim.dtype)
    s_b2a = max_over_a.sum(dim=2) / denom_b  # (B, B)

    # Symmetric InfoNCE over the aggregated similarities.
    logits_a2b = s_a2b / temperature
    logits_b2a = s_b2a / temperature
    labels = torch.arange(B, device=device, dtype=torch.long)
    loss_a2b = F.cross_entropy(logits_a2b, labels)
    loss_b2a = F.cross_entropy(logits_b2a, labels)
    return 0.5 * (loss_a2b + loss_b2a)
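
# Usage sketch for `filip_loss_with_mask` (hypothetical shapes): 8 spectra with
# up to 50 peaks against 8 graphs with up to 40 nodes, padding masked out.
#
#   peaks = torch.randn(8, 50, 128)
#   nodes = torch.randn(8, 40, 128)
#   peak_mask = torch.arange(50).expand(8, -1) < 30   # first 30 peaks valid
#   node_mask = torch.arange(40).expand(8, -1) < 25   # first 25 nodes valid
#   loss = filip_loss_with_mask(peaks, nodes, peak_mask, node_mask, temperature=0.07)
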
def global_infonce_loss(a_tokens, b_tokens, mask_a, mask_b, temperature=0.07, agg_fn="mean"):
    """
    Global InfoNCE loss (CLIP-style) for modalities A and B.

    Args:
        a_tokens: (B, N_a, D)
        b_tokens: (B, N_b, D)
        mask_a: (B, N_a) bool (True = valid)
        mask_b: (B, N_b) bool (True = valid)
        temperature: scalar
        agg_fn: "mean" | "max" | "cls" | callable -> how to aggregate tokens into one vector

    Returns:
        scalar loss
    """
    device = a_tokens.device
    B = a_tokens.shape[0]
    mask_a = mask_a.bool()
    mask_b = mask_b.bool()

    # ---- Normalize token embeddings ----
    a = F.normalize(a_tokens, dim=-1)
    b = F.normalize(b_tokens, dim=-1)

    # ---- Aggregate tokens into one vector per sample ----
    if callable(agg_fn):
        a_global = agg_fn(a, mask_a)  # custom aggregation
        b_global = agg_fn(b, mask_b)
    elif agg_fn == "mean":
        # masked mean over valid tokens
        a_global = (a * mask_a.unsqueeze(-1)).sum(dim=1) / mask_a.sum(dim=1, keepdim=True).clamp(min=1)
        b_global = (b * mask_b.unsqueeze(-1)).sum(dim=1) / mask_b.sum(dim=1, keepdim=True).clamp(min=1)
    elif agg_fn == "max":
        # masked max (assumes every sample has at least one valid token)
        a_global = a.masked_fill(~mask_a.unsqueeze(-1), float('-inf')).max(dim=1).values
        b_global = b.masked_fill(~mask_b.unsqueeze(-1), float('-inf')).max(dim=1).values
    elif agg_fn == "cls":
        # use the first token as "cls" (assumed to always be valid)
        a_global = a[:, 0, :]
        b_global = b[:, 0, :]
    else:
        raise ValueError(f"Unknown agg_fn: {agg_fn}")

    # ---- Cosine similarity matrix between the global embeddings ----
    a_global = F.normalize(a_global, dim=-1)
    b_global = F.normalize(b_global, dim=-1)
    logits = (a_global @ b_global.T) / temperature  # (B, B)

    # ---- Symmetric InfoNCE ----
    labels = torch.arange(B, device=device)
    loss_a2b = F.cross_entropy(logits, labels)
    loss_b2a = F.cross_entropy(logits.T, labels)
    loss = 0.5 * (loss_a2b + loss_b2a)
    return loss
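
# Usage sketch (same hypothetical inputs as the FILIP example above):
#
#   loss = global_infonce_loss(peaks, nodes, peak_mask, node_mask,
#                              temperature=0.07, agg_fn="mean")
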
# ---------- PCGrad utility ----------
def pcgrad_combine(losses, shared_params):
    """
    Combine gradients of several scalar losses with PCGrad (Yu et al., 2020):
    whenever two task gradients conflict (negative dot product), project one
    onto the normal plane of the other before summing.

    Args:
        losses: list of scalar loss tensors
        shared_params: list of parameters to project/aggregate gradients for

    Returns:
        mean of the input losses (for logging only; the combined gradient is
        written directly into `p.grad` of each shared parameter)
    """
    grads_list = [torch.autograd.grad(l, shared_params, retain_graph=True, allow_unused=True)
                  for l in losses]
    # Flatten each task's gradients into one vector, substituting zeros for
    # parameters a loss does not touch, so all vectors share one layout.
    flat_grads = [torch.cat([(g if g is not None else torch.zeros_like(p)).reshape(-1)
                             for g, p in zip(grads, shared_params)])
                  for grads in grads_list]
    projected = [fg.clone() for fg in flat_grads]
    # Project away conflicting components; project against the *raw* task
    # gradients, as in the PCGrad paper.
    for i in range(len(flat_grads)):
        for j in range(len(flat_grads)):
            if i == j:
                continue
            dot = (projected[i] * flat_grads[j]).sum()
            if dot < 0:
                projected[i] = projected[i] - dot / (flat_grads[j].norm() ** 2 + 1e-12) * flat_grads[j]
    # Sum the projected gradients and write them back into the parameters.
    final_grad = sum(projected)
    pointer = 0
    for p in shared_params:
        numel = p.numel()
        p.grad = final_grad[pointer:pointer + numel].view_as(p).clone()
        pointer += numel
    # Average loss, for logging only.
    return sum(losses) / len(losses)
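
# Usage sketch for `pcgrad_combine` in a training step (hypothetical model and
# losses; one reasonable way to wire it up, not prescribed by this file). The
# function calls autograd.grad itself and writes into p.grad, so no explicit
# backward() is needed for the shared parameters:
#
#   optimizer.zero_grad()
#   loss_filip = filip_loss_with_mask(peaks, nodes, peak_mask, node_mask)
#   loss_fp = fp_loss('cosine')(pred_fp, true_fp)
#   shared = [p for p in model.encoder.parameters() if p.requires_grad]
#   log_loss = pcgrad_combine([loss_filip, loss_fp], shared)
#   optimizer.step()   # parameters outside `shared` receive no gradient here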