license: mit
base_model:
  - google-bert/bert-base-uncased

Dragon Context Encoder

This is the context (passage) encoder of the Dragon dual-encoder retrieval model, trained for dense passage retrieval.
It encodes passages only. Queries must be encoded with the corresponding Dragon Query Encoder (liyongkang/dragon-query-encoder).

Model Architecture

Usage Example

from transformers import AutoTokenizer, AutoModel
import torch

# Load query encoder
q_tokenizer = AutoTokenizer.from_pretrained("liyongkang/dragon-query-encoder")
q_model = AutoModel.from_pretrained("liyongkang/dragon-query-encoder")

# Load context encoder
p_tokenizer = AutoTokenizer.from_pretrained("liyongkang/dragon-context-encoder")
p_model = AutoModel.from_pretrained("liyongkang/dragon-context-encoder")

query = "What is Dragon in NLP?"
passage = "A dual-encoder retrieval model for dense passage retrieval."


# Tokenize. The two tokenizers are identical, so either one works for both queries and passages.
q_inputs = q_tokenizer(query, return_tensors="pt", truncation=True, padding=True)
p_inputs = p_tokenizer(passage, return_tensors="pt", truncation=True, padding=True)

with torch.no_grad():
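    # CLS pooling: use the final hidden state of the first ([CLS]) token as the embedding
    q_vec = q_model(**q_inputs).last_hidden_state[:, 0]
    p_vec = p_model(**p_inputs).last_hidden_state[:, 0]
    # The relevance score is the dot product of the query and passage embeddings
    score = (q_vec * p_vec).sum(dim=-1)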
    print("Dot product similarity:", score.item())
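
The example above scores a single query-passage pair. In retrieval, one query is usually scored against many passages at once. The following is a minimal sketch of that ranking step, assuming the passage embeddings have already been stacked into one tensor. The helper name `rank_passages` and the toy 4-dimensional vectors are hypothetical and stand in for real encoder outputs, so the snippet runs without downloading the models.

```python
import torch


def rank_passages(q_vec: torch.Tensor, p_vecs: torch.Tensor, top_k: int = 3):
    """Rank passages for one query by dot-product similarity (hypothetical helper).

    q_vec:  shape (1, hidden) or (hidden,), the query's CLS embedding.
    p_vecs: shape (num_passages, hidden), one CLS embedding per passage.
    Returns (indices, scores) of the top_k passages, best first.
    """
    # Matrix-vector product gives one dot-product score per passage.
    scores = p_vecs @ q_vec.reshape(-1)
    # Never ask for more results than there are passages.
    k = min(top_k, scores.numel())
    top = torch.topk(scores, k=k)
    return top.indices.tolist(), top.values.tolist()


# Toy vectors (assumption: real embeddings would come from the two encoders).
toy_q = torch.tensor([[1.0, 0.0, 2.0, 0.0]])
toy_p = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [0.5, 0.0, 0.5, 0.0],
])

indices, scores = rank_passages(toy_q, toy_p, top_k=2)
print("Top passages:", indices, "scores:", scores)
```

With the real models, `p_vecs` could be obtained by passing a list of passages to `p_tokenizer(..., return_tensors="pt", truncation=True, padding=True)` and taking `last_hidden_state[:, 0]` from `p_model`, exactly as in the single-passage example.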