Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
def pad_graph_nodes(mol_enc, g_n_nodes):
    """Pad concatenated per-graph node embeddings into a dense batch.

    Args:
        mol_enc: (sum_nodes, D) float tensor — node embeddings for all
            graphs concatenated along dim 0. g_n_nodes must sum to
            mol_enc.shape[0].
        g_n_nodes: list[int], number of nodes per graph (len = B).

    Returns:
        padded: (B, max_nodes, D) tensor. Slices copied from mol_enc keep
            their autograd history, so gradients flow back to mol_enc
            whenever mol_enc.requires_grad is True — no artificial
            requires_grad_ trick is needed (the previous `padded + 0`
            hack forced grad tracking even for inference inputs and added
            a useless node to the graph).
        mask: (B, max_nodes) bool tensor, True for valid (non-padding) nodes.

    Raises:
        ValueError: if g_n_nodes is empty (max() would otherwise raise a
            less informative error).
    """
    if not g_n_nodes:
        raise ValueError("g_n_nodes must contain at least one graph")
    B = len(g_n_nodes)
    D = mol_enc.shape[1]
    max_nodes = max(g_n_nodes)
    # new_zeros inherits dtype and device from mol_enc.
    padded = mol_enc.new_zeros((B, max_nodes, D))
    mask = torch.zeros((B, max_nodes), dtype=torch.bool, device=mol_enc.device)
    idx = 0
    for i, n in enumerate(g_n_nodes):
        # Copy this graph's nodes into row i; copying is a differentiable
        # op, so autograd tracks it automatically.
        padded[i, :n] = mol_enc[idx:idx + n]
        mask[i, :n] = True
        idx += n
    return padded, mask
def filip_similarity_batch(
    image_tokens,
    text_tokens,
    mask_image,
    mask_text,
    reduction="mean",  # "mean", "topk", "softmax", or "geom"
    k=5,
    temperature=0.05,
    eps=1e-6
):
    """
    Compute FILIP similarity for batches of image and text token embeddings.

    Each image token is matched to its best-aligned text token (and vice
    versa); the per-token maxima are aggregated per direction and the two
    directions averaged.

    Args:
        image_tokens: (B, N_img, D) float tensor
        text_tokens: (B, N_text, D) float tensor
        mask_image: (B, N_img) bool tensor, True for valid tokens
        mask_text: (B, N_text) bool tensor, True for valid tokens
        reduction: str, aggregation strategy: "mean", "topk", "softmax", or "geom"
        k: int, used if reduction == "topk"
        temperature: float, used if reduction == "softmax"
        eps: float, small constant for numerical stability
    Returns:
        similarities: (B,) float tensor of similarity scores
    Raises:
        ValueError: if `reduction` is not a supported strategy.
    """
    # Unit-normalize so the batched matmul yields cosine similarities.
    image_norm = F.normalize(image_tokens, p=2, dim=-1)
    text_norm = F.normalize(text_tokens, p=2, dim=-1)
    # (B, N_img, N_text) cosine similarity per pair.
    sim_matrix = torch.bmm(image_norm, text_norm.transpose(1, 2))
    # A position is valid only when both tokens are real (non-padding).
    valid_mask = mask_image.unsqueeze(2) & mask_text.unsqueeze(1)
    sim_matrix_masked = sim_matrix.masked_fill(~valid_mask, float('-inf'))
    # Best match per image token (over text) and per text token (over images).
    max_sim_img, _ = sim_matrix_masked.max(dim=2)
    max_sim_text, _ = sim_matrix_masked.max(dim=1)
    # Tokens with no valid counterpart are -inf; neutralize them to 0
    # OUT-OF-PLACE (in-place writes on autograd intermediates are fragile).
    max_sim_img = max_sim_img.masked_fill(torch.isinf(max_sim_img), 0.0)
    max_sim_text = max_sim_text.masked_fill(torch.isinf(max_sim_text), 0.0)

    def aggregate(max_sim, mask):
        # Reduce per-token maxima (B, N) to one score per sample (B,),
        # counting only valid tokens.
        count = mask.sum(dim=1).clamp(min=1).float()
        if reduction == "mean":
            return (max_sim * mask).sum(dim=1) / count
        elif reduction == "topk":
            k_eff = min(k, max_sim.size(1))
            # Push padding to -inf so topk never selects it.
            masked_vals = max_sim.masked_fill(~mask, float('-inf'))
            topk_vals, _ = torch.topk(masked_vals, k_eff, dim=1)
            # If fewer than k_eff valid tokens exist, the -inf fillers
            # contribute 0 to the average.
            topk_vals = topk_vals.masked_fill(torch.isinf(topk_vals), 0.0)
            return topk_vals.sum(dim=1) / k_eff
        elif reduction == "softmax":
            masked_vals = max_sim.masked_fill(~mask, float('-inf'))
            weights = torch.softmax(masked_vals / temperature, dim=1)
            # torch.where (rather than multiplying by mask) so a fully
            # masked row's NaN weights (softmax over all -inf) are dropped
            # instead of propagated.
            weights = torch.where(mask, weights, torch.zeros_like(weights))
            weights = weights / weights.sum(dim=1, keepdim=True).clamp(min=eps)
            return (weights * max_sim).sum(dim=1)
        elif reduction == "geom":
            # Geometric mean over VALID tokens only. BUG FIX: the previous
            # version clamped padding slots to eps and summed their log(eps)
            # into the numerator, dragging the mean toward eps whenever a
            # sample had padding. Zero out padding's log contribution.
            vals = (max_sim * mask).clamp(min=eps)
            log_vals = torch.log(vals) * mask
            return torch.exp(log_vals.sum(dim=1) / count)
        else:
            raise ValueError(f"Unknown reduction type: {reduction}")

    # Aggregate both directions and average them symmetrically.
    avg_img = aggregate(max_sim_img, mask_image)
    avg_text = aggregate(max_sim_text, mask_text)
    similarity = (avg_img + avg_text) / 2
    return similarity
def filip_similarity_single(
    image_tokens,
    text_tokens,
    reduction="mean",  # "mean", "topk", "softmax", or "geom"
    k=5,
    temperature=0.05,
    eps=1e-6
):
    """Compute the FILIP late-interaction similarity for one image/text pair.

    Every image token is matched with its best-aligned text token (and vice
    versa); the per-token maxima are reduced to a scalar per direction and
    the two directions are averaged.

    Args:
        image_tokens: (N_img, D) float tensor.
        text_tokens: (N_text, D) float tensor.
        reduction: aggregation strategy: "mean", "topk", "softmax", or "geom".
        k: number of tokens kept when reduction == "topk".
        temperature: softmax temperature when reduction == "softmax".
        eps: numerical-stability floor for the geometric mean.

    Returns:
        Scalar float tensor holding the symmetric similarity score.

    Raises:
        ValueError: if `reduction` is not a supported strategy.
    """
    def _reduce(token_maxima):
        # Collapse per-token best-match scores to a single scalar.
        if reduction == "mean":
            return token_maxima.mean()
        if reduction == "topk":
            kept = token_maxima.topk(min(k, token_maxima.numel())).values
            return kept.mean()
        if reduction == "softmax":
            attn = torch.softmax(token_maxima / temperature, dim=0)
            return torch.sum(attn * token_maxima)
        if reduction == "geom":
            # Floor at eps so log never sees non-positive values.
            floored = token_maxima.clamp(min=eps)
            return floored.log().mean().exp()
        raise ValueError(f"Unknown reduction type: {reduction}")

    # Unit-normalize so the dot products below are cosine similarities.
    img_unit = F.normalize(image_tokens, p=2, dim=-1)
    txt_unit = F.normalize(text_tokens, p=2, dim=-1)
    cos = img_unit @ txt_unit.t()  # (N_img, N_text)
    # Best-matching counterpart for every token, in both directions.
    per_image = cos.max(dim=1).values   # (N_img,)
    per_text = cos.max(dim=0).values    # (N_text,)
    return (_reduce(per_image) + _reduce(per_text)) / 2