Commit 44cb111 by Jo Kristian Bergum (1 parent: df12733)
more details on colbert model

README.md CHANGED
@@ -1,8 +1,67 @@
| 1 |
# MS Marco Ranking with ColBERT on Vespa.ai
|
| 2 |
|
| 3 |
-
This is work in progress.
|
| 4 |
-
|
| 5 |
Model is based on [ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT](https://arxiv.org/abs/2004.12832).
|
| 6 |
-
This BERT model is based on
|
| 7 |
To use this model with vespa.ai for MS Marco Passage Ranking, see
|
| 8 |
-
[MS Marco Ranking using Vespa.ai sample app](https://github.com/vespa-engine/sample-apps/tree/master/msmarco-ranking)
|
| 1 |
# MS Marco Ranking with ColBERT on Vespa.ai
|
| 2 |
|
| 3 |
Model is based on [ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT](https://arxiv.org/abs/2004.12832).
|
| 4 |
+
This BERT model is based on [google/bert_uncased_L-8_H-512_A-8](https://huggingface.co/google/bert_uncased_L-8_H-512_A-8) and trained using the
|
| 5 |
+
original [ColBERT training routine](https://github.com/stanford-futuredata/ColBERT/).
|
| 6 |
+
The model weights were tuned by training on `triples.train.small.tar.gz` from [MSMARCO-Passage-Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking).
|
| 7 |
+
|
| 8 |
+
|
| 9 |
To use this model with vespa.ai for MS Marco Passage Ranking, see
|
| 10 |
+
[MS Marco Ranking using Vespa.ai sample app](https://github.com/vespa-engine/sample-apps/tree/master/msmarco-ranking).
|
| 11 |
+
|
| 12 |
+
# MS Marco Passage Ranking
|
| 13 |
+
|
| 14 |
+
| MS Marco Passage Ranking Query Set | MRR@10 ColBERT on Vespa.ai |
|------------------------------------|----------------------------|
| Dev                                | 0.354                      |
| Eval                               | 0.347                      |
|
| 18 |
+
|
| 19 |
+
The official baseline BM25 ranking model achieves an MRR@10 of 0.16 on the eval query set and 0.167 on the dev query set.
|
| 20 |
+
See [MS Marco Passage Ranking Leaderboard](https://microsoft.github.io/msmarco/).
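
For readers unfamiliar with the metric, the following is a minimal sketch of how MRR@10 can be computed: for each query, take the reciprocal rank of the first relevant passage within the top 10 results (0 if there is none), then average over queries. The function name `mrr_at_10` and the toy data are hypothetical and only serve as illustration; the official MS Marco evaluation script may differ in details such as which queries are included in the average.

```python
def mrr_at_10(ranked_ids_per_query, relevant_ids_per_query):
    """Mean reciprocal rank of the first relevant hit within the top 10."""
    total = 0.0
    for query_id, ranked_ids in ranked_ids_per_query.items():
        relevant = relevant_ids_per_query.get(query_id, set())
        # Only the first 10 ranked passages count towards the score
        for rank, passage_id in enumerate(ranked_ids[:10], start=1):
            if passage_id in relevant:
                total += 1.0 / rank
                break
    return total / len(ranked_ids_per_query)


# Hypothetical toy data: three queries with ranked passage ids
rankings = {"q1": ["p3", "p7", "p1"], "q2": ["p9", "p2"], "q3": ["p4", "p5"]}
relevant = {"q1": {"p7"}, "q2": {"p9"}, "q3": {"p8"}}

# q1: first relevant hit at rank 2, q2: rank 1, q3: no hit
score = mrr_at_10(rankings, relevant)  # (1/2 + 1 + 0) / 3 = 0.5
print(score)
```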
|
| 21 |
+
|
| 22 |
+
## Export ColBERT query encoder to ONNX
|
| 23 |
+
The ColBERT query encoder runs inside the Vespa runtime, where it maps the textual query to its tensor representation. For this
|
| 24 |
+
we use Vespa's support for running ONNX models. One can use the following snippet to export the model for serving.
|
| 25 |
+
|
| 26 |
+
```python
from transformers import BertModel, BertPreTrainedModel
import torch
import torch.nn as nn


class VespaColBERT(BertPreTrainedModel):
    """ColBERT query encoder: BERT followed by a linear projection to 32 dimensions."""

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.linear = nn.Linear(config.hidden_size, 32, bias=False)
        self.init_weights()

    def forward(self, input_ids, attention_mask):
        # Contextual token embeddings from the last BERT layer
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        Q = self.linear(Q)
        # L2-normalize each token embedding
        return torch.nn.functional.normalize(Q, p=2, dim=2)


colbert_query_encoder = VespaColBERT.from_pretrained("vespa-engine/colbert-medium")

# Export model to ONNX for serving in Vespa
input_names = ["input_ids", "attention_mask"]
output_names = ["contextual"]

# Dummy input used for tracing: one query of at most 32 tokens
input_ids = torch.ones(1, 32, dtype=torch.int64)
attention_mask = torch.ones(1, 32, dtype=torch.int64)
args = (input_ids, attention_mask)

torch.onnx.export(colbert_query_encoder,
                  args=args,
                  f="query_encoder_colbert.onnx",
                  input_names=input_names,
                  output_names=output_names,
                  opset_version=11)
```
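
The encoder above returns one L2-normalized vector per query token. As a rough illustration of what these vectors are used for, here is a minimal sketch of the late-interaction (MaxSim) score described in the ColBERT paper: each query token vector is matched with its best-scoring passage token vector, and the maxima are summed. The helper names and the tiny 2-dimensional vectors are hypothetical; this is plain Python for clarity, not how the Vespa sample app expresses the ranking function.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def max_sim(query_vectors, passage_vectors):
    """Sum over query tokens of the maximum dot product with any passage token."""
    return sum(
        max(dot(q, d) for d in passage_vectors)
        for q in query_vectors
    )


# Hypothetical unit-length vectors (the real model produces 32 dimensions)
query = [[1.0, 0.0], [0.0, 1.0]]
passage = [[0.6, 0.8], [1.0, 0.0], [0.0, -1.0]]

# First query token matches best with [1.0, 0.0] (1.0),
# second with [0.6, 0.8] (0.8), so the score is 1.8
score = max_sim(query, passage)
print(score)
```

Because the vectors are normalized, each dot product is a cosine similarity, so the score is bounded by the number of query tokens.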
|
| 64 |
+
|
| 65 |
+
# Representing the model on Vespa.ai
|
| 66 |
+
See [Ranking with ONNX models](https://docs.vespa.ai/documentation/onnx.html) and the [MS Marco Ranking sample app](https://github.com/vespa-engine/sample-apps/tree/master/msmarco-ranking).
|
| 67 |
+
|