| """ |
| This file contains utility functions for the FLMR model. Some of these functions are adapted from the original ColBERT codebase. |
| """ |
|
|
| import torch |
| import torch.distributed as dist |
|
|
|
|
def get_rank() -> int:
    """Return the calling process's rank as reported by ``torch.distributed``.

    Thin wrapper around ``torch.distributed.get_rank()`` called without a
    group argument; any error raised by that call propagates unchanged.
    """
    rank = dist.get_rank()
    return rank
|
|
|
|
def get_world_size() -> int:
    """Return the process count reported by ``torch.distributed``.

    Thin wrapper around ``torch.distributed.get_world_size()`` called without
    a group argument; any error raised by that call propagates unchanged.
    """
    world_size = dist.get_world_size()
    return world_size
|
|
|
|
def get_default_group():
    """Return ``torch.distributed.group.WORLD``, the default process-group handle."""
    world_group = dist.group.WORLD
    return world_group
|
|
|
|
| |
def colbert_score_reduce(scores_padded, D_mask):
    """Reduce token-level similarities to one late-interaction score per document.

    Args:
        scores_padded: Similarity tensor whose first two dimensions are
            ``(num_docs, doc_len)``. ``colbert_score`` passes a tensor of shape
            ``(num_docs, doc_len, query_len)``.
        D_mask: Tensor with ``num_docs * doc_len`` elements (for example
            ``(num_docs, doc_len)`` or ``(num_docs, doc_len, 1)``). Entries that
            are truthy mark real document tokens; the rest mark padding.

    Returns:
        The maximum over the document-token dimension (dim 1) with padding
        positions excluded, summed over the last dimension. For the 3-D input
        above this is a tensor of shape ``(num_docs,)``.
    """
    num_docs, doc_len = scores_padded.size(0), scores_padded.size(1)

    # Give the mask singleton trailing dims so it broadcasts over every
    # remaining dimension of the scores. ``reshape`` (unlike ``view``) also
    # accepts non-contiguous masks.
    trailing_dims = (1,) * (scores_padded.dim() - 2)
    D_padding = ~D_mask.reshape(num_docs, doc_len, *trailing_dims).bool()

    # Fill out of place so the caller's tensor is not modified. -9999 acts as
    # "minus infinity" so padding never wins the max.
    scores = scores_padded.masked_fill(D_padding, -9999)

    return scores.max(1).values.sum(-1)
|
|
|
|
def colbert_score(Q, D_padded, D_mask, use_gpu=False):
    """
    Score passages against queries with ColBERT-style late interaction.

    Expected sizes are Q = (1 | num_docs, *, dim) and D_padded = (num_docs, *, dim).
    When Q.size(0) is 1, that single query matrix is scored against every passage;
    otherwise each query matrix is scored against the passage at the same index.
    With use_gpu set, Q, D_padded and D_mask are moved to CUDA first.
    Q is cast to D_padded's dtype before the similarity matrix is computed.

    Returns the result of colbert_score_reduce on the token-level similarity
    matrix (one score per passage).

    TODO: Consider masking with -inf for the maxsim (or enforcing a ReLU).
    """
    if use_gpu:
        Q = Q.cuda()
        D_padded = D_padded.cuda()
        D_mask = D_mask.cuda()

    assert Q.dim() == 3, Q.size()
    assert D_padded.dim() == 3, D_padded.size()
    num_queries = Q.size(0)
    assert num_queries == 1 or num_queries == D_padded.size(0)

    # (num_docs, doc_len, dim) x (1 | num_docs, dim, query_len)
    #   -> (num_docs, doc_len, query_len)
    Q_transposed = Q.to(dtype=D_padded.dtype).transpose(1, 2)
    token_scores = torch.matmul(D_padded, Q_transposed)

    return colbert_score_reduce(token_scores, D_mask)
|
|
|
|
| def _sort_by_length(ids, mask, bsize, *args): |
| if ids.size(0) <= bsize: |
| return ids, mask, torch.arange(ids.size(0)) |
|
|
| indices = mask.sum(-1).sort().indices |
| reverse_indices = indices.sort().indices |
|
|
| return_array = [ids[indices], mask[indices]] |
| for arg in args: |
| if isinstance(arg, torch.Tensor): |
| return_array.append(arg[indices]) |
| else: |
| |
| return_array.append([arg[i] for i in indices]) |
|
|
| return *return_array, reverse_indices |
|
|
|
|
| def _split_into_batches(ids, mask, bsize, *args): |
| batches = [] |
| for offset in range(0, ids.size(0), bsize): |
| batch = [ids[offset : offset + bsize], mask[offset : offset + bsize]] |
| for arg in args: |
| batch.append(arg[offset : offset + bsize]) |
| batches.append(batch) |
| return batches |
|
|