Spaces:
Sleeping
Sleeping
File size: 7,056 Bytes
2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc 2c0063e 19a4dfc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import torch
from torch import nn
import torch.nn.functional as F
def pad_graph_nodes(mol_enc, g_n_nodes):
    """Pad concatenated per-graph node embeddings into a dense batch.

    Args:
        mol_enc: (sum_nodes, D) tensor, node embeddings of all graphs
            concatenated along dim 0, in graph order.
        g_n_nodes: list[int], number of nodes per graph (length B). May be
            empty, in which case empty (0, 0, D) / (0, 0) tensors are returned.
    Returns:
        padded: (B, max_nodes, D) tensor. Row ``i`` holds the nodes of graph
            ``i`` followed by zero padding. For floating-point (or complex)
            ``mol_enc`` the result is deliberately made part of the autograd
            graph, so with grad mode enabled it requires grad even when
            ``mol_enc`` itself does not.
        mask: (B, max_nodes) bool tensor, True for valid nodes.
    """
    B = len(g_n_nodes)
    D = mol_enc.shape[1]
    # default=0 keeps an empty batch from raising ValueError in max().
    max_nodes = max(g_n_nodes, default=0)

    padded = torch.zeros(
        B, max_nodes, D, dtype=mol_enc.dtype, device=mol_enc.device
    )
    # Force gradient tracking by turning `padded` into a non-leaf tensor.
    # Only floating-point/complex tensors can require grad, so integer or
    # bool embeddings skip this step instead of raising RuntimeError.
    if mol_enc.is_floating_point() or mol_enc.is_complex():
        padded = padded + mol_enc.new_zeros(1).requires_grad_(True)

    mask = torch.zeros(B, max_nodes, dtype=torch.bool, device=mol_enc.device)

    # Copy each graph's contiguous slice of rows into its own padded row.
    idx = 0
    for i, n in enumerate(g_n_nodes):
        padded[i, :n] = mol_enc[idx:idx + n]
        mask[i, :n] = True
        idx += n
    return padded, mask
# def pad_graph_nodes(mol_enc, g_n_nodes):
# """
# Args:
# mol_enc: 2D tensor of shape (sum_nodes, D)
# Node embeddings for each molecule.
# g_n_nodes: list[int] Number of nodes per graph (len = B)
# Returns:
# padded: (B, max_nodes, D) tensor
# mask: (B, max_nodes) bool tensor, True for valid nodes
# """
# # Already concatenated: shape (sum_nodes, D)
# B = len(g_n_nodes)
# D = mol_enc.shape[1]
# max_nodes = max(g_n_nodes)
# padded = mol_enc.new_zeros((B, max_nodes, D))
# mask = torch.zeros((B, max_nodes), dtype=torch.bool, device=mol_enc.device)
# idx = 0
# for i, n in enumerate(g_n_nodes):
# padded[i, :n] = mol_enc[idx:idx+n]
# mask[i, :n] = True
# idx += n
# return padded, mask
def filip_similarity_batch(
    image_tokens,
    text_tokens,
    mask_image,
    mask_text,
    reduction="mean",  # "mean", "topk", "softmax", or "geom"
    k=5,
    temperature=0.05,
    eps=1e-6
):
    """
    Compute FILIP similarity for batches of image and text token embeddings.

    Tokens are L2-normalised, a per-sample cosine similarity matrix is built,
    and for every valid token the best match on the other side is taken. The
    per-token maxima are then aggregated per side with ``reduction`` and the
    two sides are averaged.

    Padded tokens never contribute to any reduction. A sample side with no
    valid tokens aggregates to 0 for every reduction.

    Args:
        image_tokens: (B, N_img, D) float tensor
        text_tokens: (B, N_text, D) float tensor
        mask_image: (B, N_img) bool tensor, True for valid tokens
        mask_text: (B, N_text) bool tensor, True for valid tokens
        reduction: str, aggregation strategy: "mean", "topk", "softmax", or "geom"
        k: int, used if reduction == "topk"; each sample averages over
            min(k, number of valid tokens) values
        temperature: float, used if reduction == "softmax"
        eps: float, small constant for numerical stability
    Returns:
        similarities: (B,) float tensor of similarity scores
    Raises:
        ValueError: if the token tensors are not 3-D or ``reduction`` is unknown.
    """
    if image_tokens.dim() != 3 or text_tokens.dim() != 3:
        raise ValueError(
            "image_tokens and text_tokens must be 3-D (B, N, D) tensors"
        )

    neg_inf = float("-inf")

    # Normalize tokens so the batched matmul yields cosine similarities.
    image_norm = F.normalize(image_tokens, p=2, dim=-1)
    text_norm = F.normalize(text_tokens, p=2, dim=-1)
    sim_matrix = torch.bmm(image_norm, text_norm.transpose(1, 2))

    # A pair is valid only if both its image token and its text token are.
    valid_mask = mask_image.unsqueeze(2) & mask_text.unsqueeze(1)
    sim_matrix_masked = sim_matrix.masked_fill(~valid_mask, neg_inf)

    # Best match per image token (over text) and per text token (over image).
    max_sim_img, _ = sim_matrix_masked.max(dim=2)
    max_sim_text, _ = sim_matrix_masked.max(dim=1)
    # Tokens with no valid counterpart end up at -inf; neutralise them to 0.
    max_sim_img = max_sim_img.masked_fill(max_sim_img == neg_inf, 0.0)
    max_sim_text = max_sim_text.masked_fill(max_sim_text == neg_inf, 0.0)

    def aggregate(max_sim, mask):
        """Reduce (B, N) per-token maxima to (B,) using valid tokens only."""
        count = mask.sum(dim=1).clamp(min=1).float()
        has_valid = mask.any(dim=1)

        if reduction == "mean":
            return (max_sim * mask).sum(dim=1) / count

        if reduction == "topk":
            k_eff = min(k, max_sim.size(1))
            # Push padded tokens to -inf so they are selected last.
            masked_vals = max_sim.masked_fill(~mask, neg_inf)
            topk_vals, _ = torch.topk(masked_vals, k_eff, dim=1)
            topk_vals = topk_vals.masked_fill(topk_vals == neg_inf, 0.0)
            # Average over the number of values actually selected: a sample
            # with fewer than k valid tokens must not be divided by k.
            denom = count.clamp(max=max(k_eff, 1))
            return topk_vals.sum(dim=1) / denom

        if reduction == "softmax":
            # Only mask padding in rows that have at least one valid token;
            # an all -inf row would make softmax return NaN.
            pad = ~mask & has_valid.unsqueeze(1)
            masked_vals = max_sim.masked_fill(pad, neg_inf)
            weights = torch.softmax(masked_vals / temperature, dim=1) * mask
            weights = weights / weights.sum(dim=1, keepdim=True).clamp(min=eps)
            return (weights * max_sim).sum(dim=1)

        if reduction == "geom":
            # Geometric mean via mean of logs. Padded positions contribute
            # log-weight 0 instead of log(eps), which would otherwise drag
            # the result towards zero.
            log_vals = torch.log(max_sim.clamp(min=eps)).masked_fill(~mask, 0.0)
            geom_mean = torch.exp(log_vals.sum(dim=1) / count)
            # exp(0) == 1 for a side with no valid tokens; report 0 instead,
            # matching the other reductions.
            return geom_mean * has_valid

        raise ValueError(f"Unknown reduction type: {reduction}")

    # Aggregate both sides
    avg_img = aggregate(max_sim_img, mask_image)
    avg_text = aggregate(max_sim_text, mask_text)

    # Final similarity: symmetric average of both directions.
    similarity = (avg_img + avg_text) / 2
    return similarity
def filip_similarity_single(
    image_tokens,
    text_tokens,
    reduction="mean",  # "mean", "topk", "softmax", or "geom"
    k=5,
    temperature=0.05,
    eps=1e-6
):
    """
    FILIP similarity between one image and one text (unpadded, no masks).

    Every token is matched to its most similar token on the other side
    (cosine similarity); the per-token best scores are reduced per side
    and the two sides are averaged.

    Args:
        image_tokens: (N_img, D) float tensor
        text_tokens: (N_text, D) float tensor
        reduction: str, one of "mean", "topk", "softmax", "geom"
        k: int, number of best scores kept when reduction == "topk"
            (capped at the number of tokens on that side)
        temperature: float, softmax temperature when reduction == "softmax"
        eps: float, lower clamp applied before the log when reduction == "geom"
    Returns:
        similarity: float scalar tensor
    Raises:
        ValueError: if ``reduction`` is not one of the supported names.
    """
    # Unit-length tokens make the matrix product a cosine similarity table.
    unit_image = F.normalize(image_tokens, p=2, dim=-1)
    unit_text = F.normalize(text_tokens, p=2, dim=-1)
    cosine = torch.matmul(unit_image, unit_text.t())  # (N_img, N_text)

    # Best counterpart for each image token and for each text token.
    best_per_image = cosine.max(dim=1).values  # (N_img,)
    best_per_text = cosine.max(dim=0).values  # (N_text,)

    def reduce_scores(scores):
        """Collapse a 1-D vector of best-match scores into a scalar."""
        if reduction == "mean":
            return scores.mean()
        if reduction == "topk":
            n_keep = min(k, scores.numel())
            return torch.topk(scores, n_keep).values.mean()
        if reduction == "softmax":
            attn = torch.softmax(scores / temperature, dim=0)
            return (attn * scores).sum()
        if reduction == "geom":
            # Geometric mean computed as exp(mean(log(.))).
            return torch.log(scores.clamp(min=eps)).mean().exp()
        raise ValueError(f"Unknown reduction type: {reduction}")

    image_side = reduce_scores(best_per_image)
    text_side = reduce_scores(best_per_text)

    # Symmetric score: average of image->text and text->image.
    return (image_side + text_side) / 2
|