Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import hashlib | |
| import torch | |
def file_hash(file):
    """Return the BLAKE2b hex digest of the file at *file*.

    The file is opened in binary mode and fed to the hasher in 8192-byte
    chunks, so the whole file is never held in memory at once.
    Ref: https://stackoverflow.com/a/59056837
    """
    digest = hashlib.blake2b()
    with open(file, "rb") as stream:
        # iter() with a sentinel keeps calling read() until it returns b"" (EOF).
        for block in iter(lambda: stream.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()
def refine_cosine(Xa, Xq, I, device, k=None):
    """Re-rank candidate neighbours of each query by inner-product similarity.

    For each query row ``Xq[b]``, every candidate ``Xa[I[b, j]]`` is scored with
    the dot product. Candidates are then sorted by descending score and the
    top ``k`` are kept. The score is a cosine similarity only if the rows of
    ``Xa`` and ``Xq`` are L2-normalised.
    NOTE(review): presumably the caller normalises them; confirm.

    Args:
        Xa: Database vectors that support integer-array indexing. Assumed to be
            a numpy array of shape ``(N, d)``; TODO confirm.
        Xq: Query vectors of shape ``(bs, d)``.
        I: Candidate indices into ``Xa``, shape ``(bs, K)``.
        device: Torch device on which the similarities are computed.
        k: Number of candidates to keep per query. Defaults to ``K`` and must
            not exceed it.

    Returns:
        A tuple ``(S_refined, I_refined)`` of C-contiguous numpy arrays, each of
        shape ``(bs, k)``. ``S_refined`` holds the similarities in descending
        order and ``I_refined`` the corresponding indices into ``Xa``.

    Raises:
        AssertionError: If ``k`` exceeds the number of candidates ``K``
            (unless Python runs with ``-O``).
    """
    I = np.asarray(I)
    if k is not None:
        assert k <= I.shape[1]
    else:
        k = I.shape[1]

    Xi = torch.tensor(Xq, device=device).unsqueeze(1)  # bs x 1 x d
    Xj = torch.tensor(Xa[I.flatten()], device=device)  # K * bs x d
    Xj = Xj.reshape(*I.shape, Xq.shape[-1])  # bs x K x d
    sim = torch.sum(Xi * Xj, dim=-1)  # bs x K

    sort_idx = torch.argsort(sim, dim=1, descending=True).cpu().numpy()
    # Reorder every row in one vectorized gather rather than a per-row Python
    # loop. Unlike np.stack over a list of rows, this also handles an empty
    # batch (bs == 0). ascontiguousarray keeps the outputs contiguous, as
    # np.stack did.
    I_refined = np.ascontiguousarray(np.take_along_axis(I, sort_idx, axis=1)[:, :k])
    S_refined = np.ascontiguousarray(
        np.take_along_axis(sim.cpu().numpy(), sort_idx, axis=1)[:, :k]
    )
    return S_refined, I_refined