import torch

from colbert.modeling.colbert import ColBERT
from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer
from colbert.utils.amp import MixedPrecisionManager
from colbert.parameters import DEVICE


class ModelInference():
    def __init__(self, colbert: ColBERT, amp=False):
        # Inference-only wrapper: the underlying ColBERT model must be in eval mode.
        assert colbert.training is False

        self.colbert = colbert
        self.query_tokenizer = QueryTokenizer(colbert.query_maxlen)
        self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen)

        self.amp_manager = MixedPrecisionManager(amp)

    def query(self, *args, to_cpu=False, **kw_args):
        # Encode pre-tokenized queries without gradients, under optional AMP autocast.
        with torch.no_grad():
            with self.amp_manager.context():
                Q = self.colbert.query(*args, **kw_args)
                return Q.cpu() if to_cpu else Q

    def doc(self, *args, to_cpu=False, **kw_args):
        # Encode pre-tokenized documents without gradients, under optional AMP autocast.
        with torch.no_grad():
            with self.amp_manager.context():
                D = self.colbert.doc(*args, **kw_args)
                return D.cpu() if to_cpu else D

    def queryFromText(self, queries, bsize=None, to_cpu=False):
        # Tokenize raw query strings and encode them, optionally in batches of `bsize`.
        if bsize:
            batches = self.query_tokenizer.tensorize(queries, bsize=bsize)
            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu)
                       for input_ids, attention_mask in batches]
            return torch.cat(batches)

        input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
        return self.query(input_ids, attention_mask, to_cpu=to_cpu)

    def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False):
        # Tokenize raw document strings and encode them. With `bsize`, the tokenizer
        # may reorder documents (e.g., by length) for efficient batching and returns
        # `reverse_indices` so the outputs can be restored to the original input order.
        if bsize:
            batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
                       for input_ids, attention_mask in batches]

            if keep_dims:
                # Zero-pad all batches to a common length, then undo the reordering.
                D = _stack_3D_tensors(batches)
                return D[reverse_indices]

            # keep_dims=False: each document is returned as its own variable-length tensor.
            D = [d for batch in batches for d in batch]
            return [D[idx] for idx in reverse_indices.tolist()]

        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        return self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)

    def score(self, Q, D, mask=None, lengths=None, explain=False):
        # Late-interaction MaxSim scoring. Expects Q as (batch, dim, query_len),
        # i.e., already permuted, and D as (batch, doc_len, dim).
        if lengths is not None:
            assert mask is None, "don't supply both mask and lengths"

            # Build a boolean (batch, doc_len) mask marking real (non-padding) positions.
            mask = torch.arange(D.size(1), device=DEVICE) + 1
            mask = mask.unsqueeze(0) <= lengths.to(DEVICE).unsqueeze(-1)

        scores = (D @ Q)                                                  # (batch, doc_len, query_len)
        scores = scores if mask is None else scores * mask.unsqueeze(-1)  # zero out padded doc positions
        scores = scores.max(1)                                            # MaxSim over document tokens

        if explain:
            assert False, "TODO"

        return scores.values.sum(-1).cpu()                                # sum over query tokens
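
# Usage sketch: encoding raw text and scoring with late-interaction MaxSim.
# `load_colbert` is a hypothetical stand-in for a checkpoint-loading routine
# (not defined here), and the strings are placeholders. Note that `score`
# expects Q permuted to (batch, dim, query_len) so that D @ Q works out.
#
#   colbert = load_colbert("/path/to/checkpoint")  # hypothetical loader; must return an eval-mode ColBERT
#   inference = ModelInference(colbert, amp=True)
#
#   Q = inference.queryFromText(["what is late interaction?"])       # (1, query_maxlen, dim)
#   D = inference.docFromText(["ColBERT matches per-token vectors"])  # (1, doc_maxlen, dim)
#   scores = inference.score(Q.permute(0, 2, 1), D)                  # one MaxSim score per pair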


def _stack_3D_tensors(groups):
    # Stack a list of (batch, length, dim) tensors into one (total_batch, max_length, dim)
    # tensor, zero-padding shorter groups along the length axis.
    bsize = sum([x.size(0) for x in groups])
    maxlen = max([x.size(1) for x in groups])
    hdim = groups[0].size(2)

    output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)

    offset = 0
    for x in groups:
        endpos = offset + x.size(0)
        output[offset:endpos, :x.size(1)] = x
        offset = endpos

    return output
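

if __name__ == '__main__':
    # Minimal self-check sketch for _stack_3D_tensors: shorter groups are
    # zero-padded along the length axis before being stacked together.
    a = torch.ones(2, 3, 4)  # two sequences of length 3, dim 4
    b = torch.ones(1, 5, 4)  # one sequence of length 5, dim 4

    out = _stack_3D_tensors([a, b])

    assert out.shape == (3, 5, 4)
    assert torch.all(out[:2, 3:] == 0)  # rows from `a` are zero-padded past length 3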