| | """ |
| | This file contains utility functions for the FLMR model. Some of these functions are adapted from the original ColBERT codebase. |
| | """ |
| |
|
| | import torch |
| | import torch.distributed as dist |
| |
|
| |
|
def get_rank():
    """Return the rank of this process in the default process group.

    Falls back to ``0`` when ``torch.distributed`` is unavailable or no default
    process group has been initialized (plain single-process use). Calling
    ``dist.get_rank()`` directly in that situation raises instead.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_rank()
| |
|
| |
|
def get_world_size():
    """Return the number of processes in the default process group.

    Falls back to ``1`` when ``torch.distributed`` is unavailable or no default
    process group has been initialized (plain single-process use). Calling
    ``dist.get_world_size()`` directly in that situation raises instead.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size()
| |
|
| |
|
def get_default_group():
    """Return the handle of the default ("world") process group.

    This is ``torch.distributed.group.WORLD`` as-is. NOTE(review): before
    ``init_process_group`` is called, recent PyTorch versions appear to return
    ``None`` here rather than raising — confirm for the pinned torch version.
    """
    world_group = dist.group.WORLD
    return world_group
| |
|
| |
|
| | |
def colbert_score_reduce(scores_padded, D_mask):
    """Reduce token-level similarity scores to one MaxSim score per document.

    Args:
        scores_padded: similarity tensor whose first two dims are
            ``(num_docs, doc_len)``. As produced by ``colbert_score`` it is
            ``(num_docs, doc_len, q_len)``. It is MODIFIED IN PLACE: padded
            document positions are overwritten with ``-9999``. This avoids
            allocating a second scores tensor.
        D_mask: document-token mask with ``num_docs * doc_len`` elements.
            Non-zero entries mark real tokens, zero entries mark padding.

    Returns:
        Tensor with one score per document. For each query token, take the
        maximum over document tokens (dim 1), then sum over the last dim.
    """
    num_docs, doc_len = scores_padded.size(0), scores_padded.size(1)
    is_padding = D_mask.view(num_docs, doc_len).bool().logical_not()

    # -9999 acts as a sentinel so padded positions never win the max.
    # This assumes real similarities stay well above it.
    scores_padded[is_padding] = -9999

    best_per_query_token, _ = scores_padded.max(dim=1)
    return best_per_query_token.sum(dim=-1)
| |
|
| |
|
def colbert_score(Q, D_padded, D_mask, use_gpu=False):
    """Compute ColBERT late-interaction (MaxSim) scores of documents against queries.

    Args:
        Q: 3-D query embeddings ``(1 | num_docs, q_len, dim)``. With a leading
            size of 1, the single query is compared with every document.
            Otherwise query ``i`` is compared only with the aligned document ``i``.
        D_padded: 3-D padded document embeddings ``(num_docs, doc_len, dim)``.
        D_mask: document-token mask, passed to ``colbert_score_reduce``.
            Non-zero entries mark real tokens.
        use_gpu: if True, ``Q``, ``D_padded`` and ``D_mask`` are moved with ``.cuda()``.

    Returns:
        One score per document, as computed by ``colbert_score_reduce``.

    TODO (carried over): consider masking with -inf for the maxsim, or enforcing a ReLU.
    """
    if use_gpu:
        Q = Q.cuda()
        D_padded = D_padded.cuda()
        D_mask = D_mask.cuda()

    assert Q.dim() == 3, Q.size()
    assert D_padded.dim() == 3, D_padded.size()
    assert Q.size(0) in [1, D_padded.size(0)]

    # Cast Q to the document dtype so mixed-precision inputs can be multiplied.
    # (num_docs, doc_len, dim) @ (1 | num_docs, dim, q_len) -> (num_docs, doc_len, q_len)
    queries_t = Q.to(dtype=D_padded.dtype).transpose(1, 2)
    token_scores = D_padded @ queries_t

    return colbert_score_reduce(token_scores, D_mask)
| |
|
| |
|
| | def _sort_by_length(ids, mask, bsize, *args): |
| | if ids.size(0) <= bsize: |
| | return ids, mask, torch.arange(ids.size(0)) |
| |
|
| | indices = mask.sum(-1).sort().indices |
| | reverse_indices = indices.sort().indices |
| |
|
| | return_array = [ids[indices], mask[indices]] |
| | for arg in args: |
| | if isinstance(arg, torch.Tensor): |
| | return_array.append(arg[indices]) |
| | else: |
| | |
| | return_array.append([arg[i] for i in indices]) |
| |
|
| | return *return_array, reverse_indices |
| |
|
| |
|
| | def _split_into_batches(ids, mask, bsize, *args): |
| | batches = [] |
| | for offset in range(0, ids.size(0), bsize): |
| | batch = [ids[offset : offset + bsize], mask[offset : offset + bsize]] |
| | for arg in args: |
| | batch.append(arg[offset : offset + bsize]) |
| | batches.append(batch) |
| | return batches |
| |
|