liyongkang commited on
Commit
3d9f350
·
verified ·
1 Parent(s): 6b3862a

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +52 -0
README.md ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ base_model:
4
+ - google-bert/bert-base-uncased
5
+ ---
6
+
7
+ # Dragon Query Encoder
8
+
9
+ This is the **query encoder** of the Dragon dual-encoder retrieval model, trained for dense passage retrieval tasks.
10
+ It should be used together with the corresponding [Dragon Context Encoder](https://huggingface.co/liyongkang/dragon-context-encoder) for end-to-end retrieval.
11
+
12
+
13
+ ## Model Architecture
14
+
15
+ - **Base model:** `bert-base-uncased`
16
+ - **Architecture:** Dense Passage Retriever (DPR) dual-encoder
17
+ - **Encoder type:** Context encoder (for passages)
18
+ - **Pooling method:** CLS pooling (take `[CLS]` token representation)
19
+
20
+ - **Checkpoint origin:**
21
+ The weights were converted from the official [facebookresearch/dpr-scale Dragon implementation](https://github.com/facebookresearch/dpr-scale/tree/main/dragon),
22
+ specifically from the checkpoint provided at:
23
+ [https://dl.fbaipublicfiles.com/dragon/checkpoints/DRAGON/checkpoint_best.ckpt](https://dl.fbaipublicfiles.com/dragon/checkpoints/DRAGON/checkpoint_best.ckpt)
24
+
25
+
26
+ ## Usage Example
27
+
28
+ ```python
29
+ from transformers import AutoTokenizer, AutoModel
30
+ import torch
31
+
32
+ # Load query encoder
33
+ q_tokenizer = AutoTokenizer.from_pretrained("liyongkang/dragon-query-encoder")
34
+ q_model = AutoModel.from_pretrained("liyongkang/dragon-query-encoder")
35
+
36
+ # Load context encoder
37
+ p_tokenizer = AutoTokenizer.from_pretrained("liyongkang/dragon-context-encoder")
38
+ p_model = AutoModel.from_pretrained("liyongkang/dragon-context-encoder")
39
+
40
+ query = "What is Dragon in NLP?"
41
+ passage = "A dual-encoder retrieval model for dense passage retrieval."
42
+
43
+
44
+ # Tokenize. In fact, the two tokenizers are the same.
45
+ q_inputs = q_tokenizer(query, return_tensors="pt", truncation=True, padding=True)
46
+ p_inputs = p_tokenizer(passage, return_tensors="pt", truncation=True, padding=True)
47
+
48
+ with torch.no_grad():
49
+ q_vec = q_model(**q_inputs).last_hidden_state[:, 0] # CLS pooling
50
+ p_vec = p_model(**p_inputs).last_hidden_state[:, 0] # CLS pooling
51
+ score = (q_vec * p_vec).sum(dim=-1)
52
+ print("Dot product similarity:", score.item())