| import math |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from datasets import load_dataset |
| import faiss |
| import gradio as gr |
|
|
| from transformers import ( |
| DPRQuestionEncoder, DPRContextEncoder, |
| DPRQuestionEncoderTokenizerFast, DPRContextEncoderTokenizerFast |
| ) |
|
|
| |
| |
| |
# Run on GPU when available; every model and tensor below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
| |
| |
| |
# Fast tokenizers matching the pretrained DPR question/context encoders used below.
q_tok = DPRQuestionEncoderTokenizerFast.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
)
p_tok = DPRContextEncoderTokenizerFast.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base"
)
|
|
| |
| |
| |
class DPRBiEncoderHF(nn.Module):
    """DPR bi-encoder with optional linear projection heads and a learnable temperature.

    Wraps the pretrained facebook/dpr question and context towers. Question and
    passage embeddings are L2-normalized, so inner product equals cosine similarity.
    """

    def __init__(self, proj_dim=768, init_tau=0.07, freeze_backbones=False, use_projection=True):
        super().__init__()
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        self.p_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )

        # Optionally keep the pretrained towers fixed (train only the heads, if any).
        if freeze_backbones:
            for param in self.q_encoder.parameters():
                param.requires_grad = False
            for param in self.p_encoder.parameters():
                param.requires_grad = False

        self.use_projection = use_projection
        hidden_size = self.q_encoder.config.hidden_size
        if use_projection:
            self.q_proj = nn.Linear(hidden_size, proj_dim, bias=False)
            self.p_proj = nn.Linear(hidden_size, proj_dim, bias=False)
            out_dim = proj_dim
        else:
            self.q_proj = nn.Identity()
            self.p_proj = nn.Identity()
            out_dim = hidden_size

        # Learnable inverse temperature, stored in log space (CLIP-style).
        self.logit_scale = nn.Parameter(
            torch.tensor(math.log(1.0 / init_tau), dtype=torch.float32)
        )
        self.out_dim = out_dim

    def _norm(self, x):
        # Unit-normalize along the embedding dimension.
        return F.normalize(x, dim=-1)

    @torch.no_grad()
    def encode_questions(self, **q_inputs):
        """Embed tokenized questions; returns L2-normalized pooled vectors."""
        pooled = self.q_encoder(**q_inputs, return_dict=True).pooler_output
        return self._norm(self.q_proj(pooled))

    @torch.no_grad()
    def encode_passages(self, **p_inputs):
        """Embed tokenized passages; returns L2-normalized pooled vectors."""
        pooled = self.p_encoder(**p_inputs, return_dict=True).pooler_output
        return self._norm(self.p_proj(pooled))
|
|
| |
| |
| |
def load_model(weights_path="dpr_biencoder_hf.pt"):
    """Instantiate the bi-encoder, load fine-tuned weights, and switch to eval mode.

    Args:
        weights_path: Path to a state_dict checkpoint produced by training.

    Returns:
        A DPRBiEncoderHF moved to `device`, in eval mode.
    """
    model = DPRBiEncoderHF(
        proj_dim=768,
        init_tau=0.07,
        freeze_backbones=False,
        use_projection=True
    ).to(device)
    # weights_only=True: the checkpoint is a plain state_dict of tensors, so
    # refuse arbitrary pickled objects (safer, and silences the torch.load
    # untrusted-pickle warning).
    state = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.eval()
    return model
|
|
# Load the fine-tuned bi-encoder once at startup so the demo is ready to serve.
model = load_model("dpr_biencoder_hf.pt")
|
|
| |
| |
| |
def build_corpus(split="validation[:5%]"):
    """Load a SQuAD slice and return its unique contexts as corpus records."""
    squad = load_dataset("squad", split=split)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    unique_contexts = dict.fromkeys(squad["context"])
    return [
        {"id": doc_id, "title": "", "text": context}
        for doc_id, context in enumerate(unique_contexts)
    ]
|
|
def batch_p(titles, texts, max_len=256):
    """Tokenize (title, passage) pairs and move the resulting tensors to `device`."""
    encoded = p_tok(
        text=titles,
        text_pair=texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    )
    return {name: tensor.to(device) for name, tensor in encoded.items()}
|
|
def batch_q(questions, max_len=64):
    """Tokenize a list of questions and move the resulting tensors to `device`."""
    encoded = q_tok(
        questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    )
    return {name: tensor.to(device) for name, tensor in encoded.items()}
|
|
@torch.no_grad()
def build_faiss_index(model, corpus, batch_size=64, p_max=256):
    """Encode every passage in `corpus` and add the vectors to a FAISS index.

    Embeddings are unit-normalized by the model, so the inner-product index
    ranks by cosine similarity.

    Args:
        model: Bi-encoder exposing `encode_passages` and `out_dim`.
        corpus: List of {"id", "title", "text"} records.
        batch_size: Passages encoded per forward pass.
        p_max: Max token length per passage.

    Returns:
        (index, corpus) — the populated IndexFlatIP and the corpus unchanged.
    """
    # Robustness: np.vstack([]) raises on an empty corpus; return an empty
    # index of the model's output dimensionality instead.
    if not corpus:
        return faiss.IndexFlatIP(model.out_dim), corpus

    all_vecs = []
    for start in range(0, len(corpus), batch_size):
        batch = corpus[start:start + batch_size]
        titles = [doc["title"] for doc in batch]
        texts = [doc["text"] for doc in batch]
        zp = model.encode_passages(**batch_p(titles, texts, p_max))
        all_vecs.append(zp.cpu().numpy().astype("float32"))
    doc_embs = np.vstack(all_vecs)

    index = faiss.IndexFlatIP(doc_embs.shape[1])
    index.add(doc_embs)
    return index, corpus
|
|
| |
# Build the retrieval corpus (unique contexts from 5% of SQuAD validation)
# and index its passage embeddings once at startup.
corpus = build_corpus("validation[:5%]")
index, corpus = build_faiss_index(model, corpus)
|
|
| |
| |
| |
@torch.no_grad()
def encode_question(model, question, q_max=64):
    """Embed a single question as a (1, dim) float32 numpy array."""
    inputs = batch_q([question], q_max)
    vec = model.encode_questions(**inputs)
    return vec.cpu().numpy().astype("float32")
|
|
def retrieve(question, k=5):
    """Search the FAISS index for `question` and format the top-k passages.

    Args:
        question: Free-text query string.
        k: Number of passages to return. Gradio sliders may deliver this as
           a float, so it is clamped to a positive int before the search.

    Returns:
        A human-readable string, one ranked passage per paragraph.
    """
    k = max(1, int(k))  # FAISS search requires a positive integer k
    q = encode_question(model, question)
    scores, idxs = index.search(q, k)
    idxs, scores = idxs[0], scores[0]

    lines = []
    for rank, (row, score) in enumerate(zip(idxs, scores), 1):
        doc = corpus[int(row)]
        snippet = doc["text"].replace("\n", " ")
        # Truncate very long contexts so the output panel stays readable.
        if len(snippet) > 700:
            snippet = snippet[:700] + "…"
        lines.append(f"{rank}. score={score:.3f} | id={doc['id']}\n {snippet}")
    return "\n\n".join(lines)
|
|
| |
| |
| |
# Hide Gradio's footer and widen the app container.
css = """
.gradio-container {max-width: 1200px !important;}
.footer {visibility: hidden;}
"""


# --- Gradio UI: question + top-k controls on the left, results on the right ---
with gr.Blocks(title="DPR Retriever Demo", css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "## DPR Retriever Demo\n"
        "Dense retrieval over a small subset of **SQuAD** (validation 5%). "
        "Enter a question and get the top-k passages ranked by cosine similarity."
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Query inputs.
            question = gr.Textbox(
                label="Question",
                value="Who discovered penicillin?",
                lines=2,
                placeholder="Ask something Wikipedia-ish…"
            )
            topk = gr.Slider(1, 20, value=5, step=1, label="Top-K")
            # Clicking an example fills both the question box and the slider.
            examples = gr.Examples(
                examples=[
                    ["Who discovered penicillin?", 5],
                    ["What is the capital of France?", 5],
                    ["When was the Eiffel Tower completed?", 5],
                    ["Who wrote Pride and Prejudice?", 5],
                ],
                inputs=[question, topk],
                label="Examples"
            )
            run = gr.Button("🔎 Retrieve", variant="primary")

        with gr.Column(scale=1):
            # Read-only output panel with a copy button.
            results = gr.Textbox(
                label="Top passages",
                lines=24,
                show_copy_button=True
            )

    # Wire the button to the retrieval function defined above.
    run.click(retrieve, inputs=[question, topk], outputs=[results])


if __name__ == "__main__":
    demo.launch()
|
|