---
license: mit
base_model:
- google-bert/bert-base-uncased
---
# Dragon Context Encoder

This is the **context encoder** of the Dragon dual-encoder retrieval model, trained for dense passage retrieval. It should be used together with the corresponding [Dragon Query Encoder](https://huggingface.co/liyongkang/dragon-query-encoder).
## Model Architecture

- **Base model:** `bert-base-uncased`
- **Architecture:** Dense Passage Retriever (DPR) dual-encoder
- **Encoder type:** Context encoder (for passages)
- **Pooling method:** CLS pooling (the `[CLS]` token representation)
- **Checkpoint origin:** The weights were converted from the official [facebookresearch/dpr-scale Dragon implementation](https://github.com/facebookresearch/dpr-scale/tree/main/dragon), specifically from the checkpoint at [https://dl.fbaipublicfiles.com/dragon/checkpoints/DRAGON/checkpoint_best.ckpt](https://dl.fbaipublicfiles.com/dragon/checkpoints/DRAGON/checkpoint_best.ckpt).
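
Concretely, the relevance score between a query and a passage is the plain dot product of the two encoders' `[CLS]` embeddings, with no normalization or learned interaction layer, as the usage example below shows. A minimal sketch of the scoring step (`cls_dot_score` is an illustrative helper, not part of the released code):

```python
import torch

def cls_dot_score(q_cls: torch.Tensor, p_cls: torch.Tensor) -> torch.Tensor:
    """Dot-product relevance between CLS-pooled embeddings.

    q_cls: (num_queries, hidden_size) query [CLS] vectors.
    p_cls: (num_passages, hidden_size) passage [CLS] vectors.
    Returns a (num_queries, num_passages) score matrix.
    """
    return q_cls @ p_cls.T
```

Because scoring reduces to a single matrix multiply, passage embeddings can be precomputed and indexed offline, with only the query encoded at search time.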
## Usage Example

```python
from transformers import AutoTokenizer, AutoModel
import torch

# Load the query encoder
q_tokenizer = AutoTokenizer.from_pretrained("liyongkang/dragon-query-encoder")
q_model = AutoModel.from_pretrained("liyongkang/dragon-query-encoder")

# Load the context encoder
p_tokenizer = AutoTokenizer.from_pretrained("liyongkang/dragon-context-encoder")
p_model = AutoModel.from_pretrained("liyongkang/dragon-context-encoder")

query = "What is Dragon in NLP?"
passage = "A dual-encoder retrieval model for dense passage retrieval."

# Tokenize. The two tokenizers are in fact identical, since both
# encoders are built on bert-base-uncased.
q_inputs = q_tokenizer(query, return_tensors="pt", truncation=True, padding=True)
p_inputs = p_tokenizer(passage, return_tensors="pt", truncation=True, padding=True)

with torch.no_grad():
    q_vec = q_model(**q_inputs).last_hidden_state[:, 0]  # CLS pooling
    p_vec = p_model(**p_inputs).last_hidden_state[:, 0]  # CLS pooling

score = (q_vec * p_vec).sum(dim=-1)
print("Dot product similarity:", score.item())
```
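
In practice a query is usually scored against many candidate passages at once. The sketch below reuses the models, tokenizers, and `q_vec` from the example above to rank a handful of candidates; the passages themselves are made up for illustration:

```python
candidates = [
    "A dual-encoder retrieval model for dense passage retrieval.",
    "Dragons are legendary creatures that appear in many folk tales.",
    "BERT is a bidirectional transformer pretrained on large text corpora.",
]

# Encode all candidate passages in one padded batch.
cand_inputs = p_tokenizer(candidates, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
    cand_vecs = p_model(**cand_inputs).last_hidden_state[:, 0]  # (num_passages, hidden)

# (1, num_passages) dot-product scores; rank candidates highest first.
scores = q_vec @ cand_vecs.T
for rank, idx in enumerate(scores.squeeze(0).argsort(descending=True).tolist(), start=1):
    print(f"{rank}. {scores[0, idx].item():.4f}  {candidates[idx]}")
```

For large collections, the precomputed passage vectors would typically go into an approximate nearest-neighbor index (e.g., FAISS) rather than being scored in a single matrix multiply.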