import hashlib

import numpy as np
import torch


def file_hash(file):
    """Return the hex BLAKE2b digest of the file at path *file*.

    The file is read in 8192-byte chunks so large files are hashed without
    loading them into memory at once.

    Ref: https://stackoverflow.com/a/59056837
    """
    hash_fn = hashlib.blake2b()
    with open(file, "rb") as f:
        # iter(callable, sentinel) keeps reading until f.read returns b"" (EOF).
        for chunk in iter(lambda: f.read(8192), b""):
            hash_fn.update(chunk)
    return hash_fn.hexdigest()


@torch.no_grad()
def refine_cosine(Xa, Xq, I, device, k=None):
    """Re-rank candidate neighbours of each query by exact similarity.

    Args:
        Xa: database vectors; rows are gathered with ``Xa[I.flatten()]``.
        Xq: query vectors, ``bs x d``.
        I: candidate row indices into ``Xa``, ``bs x K``.
        device: torch device on which similarities are computed.
        k: number of best candidates to keep per query (defaults to ``K``).

    Returns:
        ``(S_refined, I_refined)``, both ``bs x k`` numpy arrays: the
        similarities sorted in descending order per row, and the matching
        entries of ``I``.

    Raises:
        ValueError: if ``k`` is larger than the number of candidates ``K``.

    Notes:
        The score is a plain dot product. It equals cosine similarity only
        when rows of ``Xa`` and ``Xq`` are L2-normalised; presumably the
        caller guarantees this (not checked here).
        NOTE(review): negative entries in ``I`` (e.g. ``-1`` padding from an
        ANN search) are not filtered and would index from the end of ``Xa``.
    """
    I = np.asarray(I)
    n_cand = I.shape[1]
    if k is None:
        k = n_cand
    elif k > n_cand:
        # Explicit check instead of `assert`, which is stripped under -O.
        raise ValueError(f"k={k} exceeds the number of candidates ({n_cand})")

    # as_tensor avoids a redundant copy (and the warning torch.tensor emits
    # when handed an existing tensor).
    Xi = torch.as_tensor(Xq, device=device).unsqueeze(1)  # bs x 1 x d
    Xj = torch.as_tensor(Xa[I.flatten()], device=device)  # bs*K x d
    Xj = Xj.reshape(*I.shape, Xq.shape[-1])  # bs x K x d
    sim = torch.sum(Xi * Xj, dim=-1)  # bs x K

    # Column order of the top-k candidates per row. One vectorised gather
    # replaces the per-row Python loop and, unlike np.stack over a list,
    # also works for an empty batch (bs == 0).
    order = torch.argsort(sim, dim=1, descending=True)[:, :k].cpu().numpy()
    S_refined = np.take_along_axis(sim.cpu().numpy(), order, axis=1)
    I_refined = np.take_along_axis(I, order, axis=1)
    return S_refined, I_refined