liyongkang committed · verified
Commit 5adafeb · 1 Parent(s): c70c7a1

Create README.md
---
license: mit
base_model:
- google-bert/bert-base-uncased
---

# Dragon Context Encoder

This is the **context encoder** of the Dragon dual-encoder retrieval model, trained for dense passage retrieval tasks.
It should be used together with the corresponding [Dragon Query Encoder](https://huggingface.co/liyongkang/dragon-query-encoder).

## Model Architecture

- **Base model:** `bert-base-uncased`
- **Architecture:** Dense Passage Retriever (DPR) dual-encoder
- **Encoder type:** Context encoder (for passages)
- **Pooling method:** CLS pooling (the `[CLS]` token representation is used as the passage embedding)
- **Checkpoint origin:** The weights were converted from the official [facebookresearch/dpr-scale Dragon implementation](https://github.com/facebookresearch/dpr-scale/tree/main/dragon), specifically from the checkpoint provided at [https://dl.fbaipublicfiles.com/dragon/checkpoints/DRAGON/checkpoint_best.ckpt](https://dl.fbaipublicfiles.com/dragon/checkpoints/DRAGON/checkpoint_best.ckpt)

## Usage Example

```python
from transformers import AutoTokenizer, AutoModel
import torch

# Load query encoder
q_tokenizer = AutoTokenizer.from_pretrained("liyongkang/dragon-query-encoder")
q_model = AutoModel.from_pretrained("liyongkang/dragon-query-encoder")

# Load context encoder
p_tokenizer = AutoTokenizer.from_pretrained("liyongkang/dragon-context-encoder")
p_model = AutoModel.from_pretrained("liyongkang/dragon-context-encoder")

query = "What is Dragon in NLP?"
passage = "A dual-encoder retrieval model for dense passage retrieval."

# Tokenize. The two tokenizers are in fact identical, since both encoders
# are initialized from bert-base-uncased.
q_inputs = q_tokenizer(query, return_tensors="pt", truncation=True, padding=True)
p_inputs = p_tokenizer(passage, return_tensors="pt", truncation=True, padding=True)

# Encode with CLS pooling, then score by dot product
with torch.no_grad():
    q_vec = q_model(**q_inputs).last_hidden_state[:, 0]  # CLS pooling
    p_vec = p_model(**p_inputs).last_hidden_state[:, 0]  # CLS pooling

score = (q_vec * p_vec).sum(dim=-1)
print("Dot product similarity:", score.item())
```
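Once query and passage embeddings have been produced as above, ranking a whole corpus reduces to a single matrix product of CLS vectors. A minimal sketch with toy tensors standing in for encoder outputs (the random embeddings here are placeholders, not real model outputs):

```python
import torch

# Toy embeddings standing in for the CLS vectors from q_model / p_model;
# BERT-base produces 768-dimensional hidden states.
torch.manual_seed(0)
q_vec = torch.randn(1, 768)    # one query embedding
p_vecs = torch.randn(4, 768)   # four passage embeddings

# Dot-product similarity of the query against every passage at once
scores = q_vec @ p_vecs.T      # shape: (1, 4)

# Rank passages by similarity, highest first
ranking = scores.squeeze(0).argsort(descending=True)
print("Scores:", scores.squeeze(0).tolist())
print("Ranking (best to worst):", ranking.tolist())
```

For large corpora, the passage embeddings would typically be precomputed offline and searched with an approximate nearest-neighbor index rather than a dense matrix product.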