# flare/utils/general.py — FLARE general utilities
# (header reconstructed as a comment: the original lines were web-page
# residue — repo breadcrumb, author caption, commit message and hash —
# which is not valid Python)
import torch
from torch import nn
import torch.nn.functional as F
def pad_graph_nodes(mol_enc, g_n_nodes):
    """
    Pad concatenated per-graph node embeddings into a dense batch tensor.

    Args:
        mol_enc: (sum_nodes, D) tensor, node embeddings concatenated for all graphs
        g_n_nodes: list[int], number of nodes per graph (len = B)
    Returns:
        padded: (B, max_nodes, D) tensor; padding rows are zero. Gradients flow
            back to ``mol_enc`` through the valid rows whenever ``mol_enc``
            requires grad — no extra autograd trickery is needed.
        mask: (B, max_nodes) bool tensor, True for valid nodes
    Raises:
        ValueError: if ``sum(g_n_nodes)`` does not match ``mol_enc.shape[0]``.
    """
    if sum(g_n_nodes) != mol_enc.shape[0]:
        raise ValueError(
            f"g_n_nodes sums to {sum(g_n_nodes)} but mol_enc has "
            f"{mol_enc.shape[0]} rows"
        )
    # torch.split yields one (n_i, D) view per graph; pad_sequence stacks them
    # into (B, max_nodes, D), zero-padding the shorter graphs. This replaces
    # the previous `+ new_zeros(1).requires_grad_(True)` hack, which forced
    # grad tracking even in inference and added a spurious op to the graph.
    chunks = torch.split(mol_enc, list(g_n_nodes), dim=0)
    padded = nn.utils.rnn.pad_sequence(chunks, batch_first=True)
    # mask[i, j] is True iff j < g_n_nodes[i] (broadcasted comparison).
    lengths = torch.as_tensor(g_n_nodes, device=mol_enc.device)
    positions = torch.arange(padded.size(1), device=mol_enc.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask
# NOTE: a previous, functionally equivalent implementation of
# pad_graph_nodes (manual zero-tensor + per-graph copy loop) lived here as
# commented-out code; removed as dead code — recover it from version history
# if ever needed.
def filip_similarity_batch(
    image_tokens,
    text_tokens,
    mask_image,
    mask_text,
    reduction="mean",  # "mean", "topk", "softmax", or "geom"
    k=5,
    temperature=0.05,
    eps=1e-6
):
    """
    Compute FILIP-style late-interaction similarity for batches of token
    embeddings.

    Each image token is matched to its best text token (and vice versa); the
    per-token maxima are aggregated on each side with the chosen strategy and
    the two directions are averaged.

    Args:
        image_tokens: (B, N_img, D) float tensor
        text_tokens: (B, N_text, D) float tensor
        mask_image: (B, N_img) bool tensor, True for valid image tokens
        mask_text: (B, N_text) bool tensor, True for valid text tokens
        reduction: str, aggregation strategy: "mean", "topk", "softmax", or "geom"
        k: int, number of tokens kept if reduction == "topk"
        temperature: float, softmax temperature if reduction == "softmax"
        eps: float, small constant for numerical stability
    Returns:
        similarities: (B,) float tensor of similarity scores
    Raises:
        ValueError: if reduction is not one of the supported strategies.
    """
    # Cosine similarity between every (image token, text token) pair.
    image_norm = F.normalize(image_tokens, p=2, dim=-1)
    text_norm = F.normalize(text_tokens, p=2, dim=-1)
    sim_matrix = torch.bmm(image_norm, text_norm.transpose(1, 2))  # (B, N_img, N_text)

    # A pair is valid only when BOTH tokens are real (not padding).
    valid_mask = mask_image.unsqueeze(2) & mask_text.unsqueeze(1)
    neg_inf = float('-inf')
    sim_matrix_masked = sim_matrix.masked_fill(~valid_mask, neg_inf)

    # Best match for every image token and every text token.
    max_sim_img, _ = sim_matrix_masked.max(dim=2)   # (B, N_img)
    max_sim_text, _ = sim_matrix_masked.max(dim=1)  # (B, N_text)
    # Tokens with no valid partner end up at -inf; neutralize them to 0 so the
    # aggregations below stay finite.
    max_sim_img = max_sim_img.masked_fill(max_sim_img == neg_inf, 0.0)
    max_sim_text = max_sim_text.masked_fill(max_sim_text == neg_inf, 0.0)

    def aggregate(max_sim, mask):
        # Aggregate per-token best scores over the valid tokens of each row.
        count = mask.sum(dim=1).clamp(min=1).float()
        if reduction == "mean":
            return (max_sim * mask).sum(dim=1) / count
        elif reduction == "topk":
            k_eff = min(k, max_sim.size(1))
            # Push invalid tokens to -inf so topk never selects them…
            masked_vals = max_sim.masked_fill(~mask, neg_inf)
            topk_vals, _ = torch.topk(masked_vals, k_eff, dim=1)
            # …then rows with fewer than k_eff valid tokens contribute zeros.
            topk_vals = topk_vals.masked_fill(topk_vals == neg_inf, 0.0)
            return topk_vals.sum(dim=1) / k_eff
        elif reduction == "softmax":
            masked_vals = max_sim.masked_fill(~mask, neg_inf)
            weights = torch.softmax(masked_vals / temperature, dim=1)
            # Re-zero and renormalize as a guard against numerical leakage
            # into padded positions.
            weights = weights * mask
            weights = weights / weights.sum(dim=1, keepdim=True).clamp(min=eps)
            return (weights * max_sim).sum(dim=1)
        elif reduction == "geom":
            # Geometric mean over VALID tokens only. BUGFIX: the previous
            # version summed log(eps) for every padding slot while dividing by
            # the valid-token count, which collapsed the score toward zero for
            # short sequences. Multiplying the logs by the mask removes the
            # padded terms from the sum.
            log_vals = torch.log((max_sim * mask).clamp(min=eps)) * mask
            return torch.exp(log_vals.sum(dim=1) / count)
        else:
            raise ValueError(f"Unknown reduction type: {reduction}")

    # Aggregate both directions and average them.
    avg_img = aggregate(max_sim_img, mask_image)
    avg_text = aggregate(max_sim_text, mask_text)
    return (avg_img + avg_text) / 2
def filip_similarity_single(
    image_tokens,
    text_tokens,
    reduction="mean",  # "mean", "topk", "softmax", or "geom"
    k=5,
    temperature=0.05,
    eps=1e-6
):
    """
    FILIP similarity for one image/text pair (no padding masks).

    Every image token is matched to its best text token and vice versa; each
    direction's best-match scores are reduced to a scalar and the two
    directions are averaged.

    Args:
        image_tokens: (N_img, D) float tensor
        text_tokens: (N_text, D) float tensor
        reduction: str, aggregation strategy: "mean", "topk", "softmax", or "geom"
        k: int, used if reduction == "topk"
        temperature: float, used if reduction == "softmax"
        eps: float, small constant for numerical stability
    Returns:
        similarity: float scalar tensor
    Raises:
        ValueError: if reduction is not a recognized strategy.
    """
    # Unit-normalize, then one matmul gives all pairwise cosine similarities.
    img_unit = F.normalize(image_tokens, p=2, dim=-1)
    txt_unit = F.normalize(text_tokens, p=2, dim=-1)
    cosine = img_unit @ txt_unit.t()  # (N_img, N_text)

    # Best partner score per token, in both directions.
    best_per_img = cosine.max(dim=1).values   # (N_img,)
    best_per_txt = cosine.max(dim=0).values   # (N_text,)

    def _reduce(scores):
        # Collapse a 1-D score vector to a scalar with the chosen strategy.
        if reduction == "mean":
            return scores.mean()
        if reduction == "topk":
            kept = torch.topk(scores, min(k, scores.numel())).values
            return kept.mean()
        if reduction == "softmax":
            weights = torch.softmax(scores / temperature, dim=0)
            return torch.sum(weights * scores)
        if reduction == "geom":
            # Clamp keeps log() finite for non-positive scores.
            return torch.exp(torch.log(scores.clamp(min=eps)).mean())
        raise ValueError(f"Unknown reduction type: {reduction}")

    # Average the two directions into the final scalar score.
    return (_reduce(best_per_img) + _reduce(best_per_txt)) / 2