import torch
from torch import nn
import torch.nn.functional as F


def pad_graph_nodes(mol_enc, g_n_nodes):
    """Scatter concatenated per-node embeddings into a zero-padded batch tensor.

    Args:
        mol_enc: (sum_nodes, D) tensor, node embeddings of all graphs
            concatenated along dim 0 in graph order.
        g_n_nodes: sequence of non-negative ints, number of nodes per graph
            (length B). Must sum to ``mol_enc.shape[0]``.

    Returns:
        padded: (B, max_nodes, D) tensor with the dtype/device of ``mol_enc``.
            Padding positions are zero. The result is differentiable w.r.t.
            ``mol_enc`` and requires grad exactly when ``mol_enc`` does.
        mask: (B, max_nodes) bool tensor, True for valid (non-padding) nodes.

    Raises:
        ValueError: if a node count is negative or the counts do not sum to
            the number of rows of ``mol_enc``.
    """
    sizes = [int(n) for n in g_n_nodes]
    if any(n < 0 for n in sizes):
        raise ValueError(f"node counts must be non-negative, got {sizes}")
    total = sum(sizes)
    if total != mol_enc.shape[0]:
        # A mismatch would otherwise either crash mid-copy or silently drop
        # trailing rows of mol_enc.
        raise ValueError(
            f"node counts sum to {total} but mol_enc has {mol_enc.shape[0]} rows"
        )

    device = mol_enc.device
    if not sizes:
        empty = mol_enc.new_zeros((0, 0, mol_enc.shape[1]))
        return empty, torch.zeros((0, 0), dtype=torch.bool, device=device)

    # pad_sequence is autograd-aware: gradients flow from `padded` back into
    # the corresponding rows of mol_enc, while padding stays constant zero.
    padded = nn.utils.rnn.pad_sequence(
        list(torch.split(mol_enc, sizes)), batch_first=True
    )
    lengths = torch.tensor(sizes, dtype=torch.long, device=device)
    positions = torch.arange(padded.shape[1], device=device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask


def _aggregate_token_maxima(max_sim, mask, reduction, k, temperature, eps):
    """Reduce per-token max similarities over the valid tokens of each row.

    Args:
        max_sim: (B, N) float tensor of per-token best-match similarities.
        mask: (B, N) bool tensor, True for tokens that take part.
        reduction: "mean", "topk", "softmax" or "geom".
        k: number of top tokens averaged for "topk".
        temperature: softmax temperature for "softmax".
        eps: lower clamp for "geom" logs and for weight normalization.

    Returns:
        (B,) float tensor. Rows without any valid token reduce to 0 for every
        strategy.

    Raises:
        ValueError: for an unknown ``reduction``.
    """
    mask_f = mask.to(max_sim.dtype)
    count = mask_f.sum(dim=1).clamp(min=1)

    if reduction == "mean":
        return (max_sim * mask_f).sum(dim=1) / count

    if reduction == "topk":
        k_eff = min(k, max_sim.size(1))
        masked_vals = max_sim.masked_fill(~mask, float("-inf"))
        topk_vals = torch.topk(masked_vals, k_eff, dim=1).values
        topk_vals = topk_vals.masked_fill(topk_vals == float("-inf"), 0.0)
        # Average over the selected *valid* tokens: a row with fewer than k
        # valid tokens is divided by its valid count, not by k.
        return topk_vals.sum(dim=1) / torch.clamp(count, max=k_eff)

    if reduction == "softmax":
        # Use the most negative finite value instead of -inf so that rows with
        # no valid token yield finite (uniform) weights rather than NaN; those
        # weights are then zeroed by the mask.
        neg_big = torch.finfo(max_sim.dtype).min
        logits = (max_sim / temperature).masked_fill(~mask, neg_big)
        weights = torch.softmax(logits, dim=1) * mask_f
        weights = weights / weights.sum(dim=1, keepdim=True).clamp(min=eps)
        return (weights * max_sim).sum(dim=1)

    if reduction == "geom":
        # Geometric mean computed in log space over valid tokens only.
        # Non-positive similarities are clamped to eps so the log is defined.
        log_vals = torch.log(max_sim.clamp(min=eps)) * mask_f
        geom_mean = torch.exp(log_vals.sum(dim=1) / count)
        return geom_mean.masked_fill(~mask.any(dim=1), 0.0)

    raise ValueError(f"Unknown reduction type: {reduction}")


def filip_similarity_batch(
    image_tokens,
    text_tokens,
    mask_image,
    mask_text,
    reduction="mean",  # "mean", "topk", "softmax", or "geom"
    k=5,
    temperature=0.05,
    eps=1e-6,
):
    """Compute FILIP similarity for paired batches of token embeddings.

    Sample ``b`` of ``image_tokens`` is compared with sample ``b`` of
    ``text_tokens``. Each valid token takes the cosine similarity of its best
    valid match on the other side. These maxima are reduced per side, and the
    two sides are averaged.

    Args:
        image_tokens: (B, N_img, D) float tensor.
        text_tokens: (B, N_text, D) float tensor.
        mask_image: (B, N_img) bool (or 0/1) tensor, True for valid tokens.
        mask_text: (B, N_text) bool (or 0/1) tensor, True for valid tokens.
        reduction: aggregation strategy: "mean", "topk", "softmax" or "geom".
        k: number of best tokens averaged when ``reduction == "topk"``.
        temperature: softmax temperature when ``reduction == "softmax"``.
        eps: small constant for numerical stability.

    Returns:
        similarities: (B,) float tensor of similarity scores.

    Raises:
        ValueError: for an unknown ``reduction``.
    """
    mask_image = mask_image.bool()
    mask_text = mask_text.bool()

    image_norm = F.normalize(image_tokens, p=2, dim=-1)
    text_norm = F.normalize(text_tokens, p=2, dim=-1)
    # (B, N_img, N_text) cosine similarities
    sim_matrix = torch.bmm(image_norm, text_norm.transpose(1, 2))

    valid_pairs = mask_image.unsqueeze(2) & mask_text.unsqueeze(1)
    sim_masked = sim_matrix.masked_fill(~valid_pairs, float("-inf"))

    max_sim_img = sim_masked.max(dim=2).values  # (B, N_img)
    max_sim_text = sim_masked.max(dim=1).values  # (B, N_text)

    # Tokens with no valid partner (padding, or an empty opposite side) have a
    # max of -inf; zero them out-of-place to keep autograd simple.
    max_sim_img = max_sim_img.masked_fill(max_sim_img == float("-inf"), 0.0)
    max_sim_text = max_sim_text.masked_fill(max_sim_text == float("-inf"), 0.0)

    avg_img = _aggregate_token_maxima(
        max_sim_img, mask_image, reduction, k, temperature, eps
    )
    avg_text = _aggregate_token_maxima(
        max_sim_text, mask_text, reduction, k, temperature, eps
    )
    return (avg_img + avg_text) / 2


def filip_similarity_single(
    image_tokens,
    text_tokens,
    reduction="mean",  # "mean", "topk", "softmax", or "geom"
    k=5,
    temperature=0.05,
    eps=1e-6,
):
    """Compute FILIP similarity for a single image/text pair (no padding).

    Uses the same aggregation as ``filip_similarity_batch``, so the result
    equals the batched score of the same pair with all tokens marked valid.

    Args:
        image_tokens: (N_img, D) float tensor.
        text_tokens: (N_text, D) float tensor.
        reduction: aggregation strategy: "mean", "topk", "softmax" or "geom".
        k: number of best tokens averaged when ``reduction == "topk"``.
        temperature: softmax temperature when ``reduction == "softmax"``.
        eps: small constant for numerical stability.

    Returns:
        similarity: float scalar tensor.

    Raises:
        ValueError: for an unknown ``reduction``.
    """
    image_norm = F.normalize(image_tokens, p=2, dim=-1)
    text_norm = F.normalize(text_tokens, p=2, dim=-1)
    # (N_img, N_text) cosine similarity matrix
    sim_matrix = image_norm @ text_norm.t()

    def reduce_side(max_sim):
        # Treat the 1-D maxima as a one-row batch in which every token is valid.
        row = max_sim.unsqueeze(0)
        all_valid = torch.ones_like(row, dtype=torch.bool)
        return _aggregate_token_maxima(
            row, all_valid, reduction, k, temperature, eps
        )[0]

    avg_img = reduce_side(sim_matrix.max(dim=1).values)  # image -> text
    avg_text = reduce_side(sim_matrix.max(dim=0).values)  # text -> image
    return (avg_img + avg_text) / 2