# DPR-Mini / app.py
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from datasets import load_dataset
import faiss
import gradio as gr
from transformers import (
    DPRQuestionEncoder, DPRContextEncoder,
    DPRQuestionEncoderTokenizerFast, DPRContextEncoderTokenizerFast,
)
# -------------------
# Device setup
# -------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -------------------
# Tokenizers
# -------------------
q_tok = DPRQuestionEncoderTokenizerFast.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
)
p_tok = DPRContextEncoderTokenizerFast.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base"
)
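# Question and context tokenizers are both BERT-style; they are kept separate
# so each one matches the vocabulary of its corresponding pretrained encoder.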
# -------------------
# DPR BiEncoder
# -------------------
class DPRBiEncoderHF(nn.Module):
    def __init__(self, proj_dim=768, init_tau=0.07, freeze_backbones=False, use_projection=True):
        super().__init__()
        self.q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        self.p_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
        if freeze_backbones:
            for p in self.q_encoder.parameters():
                p.requires_grad = False
            for p in self.p_encoder.parameters():
                p.requires_grad = False
        self.use_projection = use_projection
        hidden = self.q_encoder.config.hidden_size  # 768
        if use_projection:
            self.q_proj = nn.Linear(hidden, proj_dim, bias=False)
            self.p_proj = nn.Linear(hidden, proj_dim, bias=False)
            out_dim = proj_dim
        else:
            self.q_proj = nn.Identity()
            self.p_proj = nn.Identity()
            out_dim = hidden
        # Learnable temperature (not used at retrieval time, but kept so the
        # trained state_dict loads cleanly)
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1.0 / init_tau), dtype=torch.float32))
        self.out_dim = out_dim

    def _norm(self, x):
        # L2-normalize along the embedding dimension
        return F.normalize(x, dim=-1)

    @torch.no_grad()
    def encode_questions(self, **q_inputs):
        out = self.q_encoder(**q_inputs, return_dict=True)
        h = out.pooler_output
        z = self.q_proj(h)
        return self._norm(z)

    @torch.no_grad()
    def encode_passages(self, **p_inputs):
        out = self.p_encoder(**p_inputs, return_dict=True)
        h = out.pooler_output
        z = self.p_proj(h)
        return self._norm(z)
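# Illustrative (not run at import time): encoding a single question by hand.
# `model` is created below via load_model(); outputs are unit-norm, so a dot
# product between question and passage vectors is exactly cosine similarity.
#
#   enc = q_tok(["who wrote hamlet?"], return_tensors="pt")
#   zq = model.encode_questions(**{k: v.to(device) for k, v in enc.items()})
#   assert abs(zq.norm(dim=-1).item() - 1.0) < 1e-4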
# -------------------
# Reload trained model
# -------------------
def load_model(weights_path="dpr_biencoder_hf.pt"):
    model = DPRBiEncoderHF(
        proj_dim=768,
        init_tau=0.07,
        freeze_backbones=False,
        use_projection=True,
    ).to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model
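# Assumes the checkpoint is a plain state_dict saved during training with
# torch.save(model.state_dict(), "dpr_biencoder_hf.pt"); a whole-model pickle
# would instead be restored with torch.load() alone.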
model = load_model("dpr_biencoder_hf.pt")
# -------------------
# Dataset + batching helpers
# -------------------
def build_corpus(split="validation[:5%]"):
    ds = load_dataset("squad", split=split)
    # SQuAD repeats the same context across many questions; dedupe while
    # preserving order so corpus ids stay stable.
    uniq = list(dict.fromkeys(ds["context"]))
    corpus = [{"id": i, "title": "", "text": ctx} for i, ctx in enumerate(uniq)]
    return corpus
def batch_p(titles, texts, max_len=256):
    enc = p_tok(text=titles, text_pair=texts, return_tensors="pt",
                padding=True, truncation=True, max_length=max_len)
    return {k: v.to(device) for k, v in enc.items()}

def batch_q(questions, max_len=64):
    enc = q_tok(questions, return_tensors="pt", padding=True, truncation=True, max_length=max_len)
    return {k: v.to(device) for k, v in enc.items()}
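# Note: DPR context encoders are trained on "title [SEP] passage" inputs,
# which is why batch_p passes titles as `text` and passages as `text_pair`.
# The corpus built above uses empty titles, so the pair reduces to the passage.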
@torch.no_grad()
def build_faiss_index(model, corpus, batch_size=64, p_max=256):
    all_vecs = []
    for i in range(0, len(corpus), batch_size):
        titles = [c["title"] for c in corpus[i:i + batch_size]]
        texts = [c["text"] for c in corpus[i:i + batch_size]]
        zp = model.encode_passages(**batch_p(titles, texts, p_max))
        all_vecs.append(zp.cpu().numpy().astype("float32"))
    doc_embs = np.vstack(all_vecs)
    dim = doc_embs.shape[1]
    index = faiss.IndexFlatIP(dim)  # inner product == cosine on normalized vectors
    index.add(doc_embs)
    return index, corpus
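# With L2-normalized vectors, IndexFlatIP scores are exact cosine similarities.
# For a corpus this small an exhaustive flat index is fine; to avoid re-encoding
# on every restart, the index could be persisted with faiss.write_index(index,
# "corpus.faiss") and reloaded via faiss.read_index (a sketch, not done here).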
# Build a small index over the first 5% of the SQuAD validation split at launch
corpus = build_corpus("validation[:5%]")
index, corpus = build_faiss_index(model, corpus)
# -------------------
# Retrieval
# -------------------
@torch.no_grad()
def encode_question(model, question, q_max=64):
    zq = model.encode_questions(**batch_q([question], q_max))
    return zq.cpu().numpy().astype("float32")
def retrieve(question, k=5):
    k = int(k)  # Gradio sliders can deliver floats; faiss expects an int k
    q = encode_question(model, question)
    scores, idxs = index.search(q, k)
    idxs, scores = idxs[0], scores[0]
    # Pretty formatting for a bigger textbox
    lines = []
    for rank, (row, score) in enumerate(zip(idxs, scores), 1):
        doc = corpus[int(row)]
        snippet = doc["text"].replace("\n", " ")
        if len(snippet) > 700:
            snippet = snippet[:700] + "…"
        lines.append(f"{rank}. score={score:.3f} | id={doc['id']}\n {snippet}")
    return "\n\n".join(lines)
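# Illustrative usage (not executed at import time):
#   print(retrieve("Who discovered penicillin?", k=3))
# prints one "rank. score | id" block per retrieved passage.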
# -------------------
# Gradio UI
# -------------------
css = """
.gradio-container {max-width: 1200px !important;}
.footer {visibility: hidden;}
"""
with gr.Blocks(title="DPR Retriever Demo", css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "## DPR Retriever Demo\n"
        "Dense retrieval over a small subset of **SQuAD** (first 5% of the validation split). "
        "Enter a question and get the top-k passages ranked by cosine similarity."
    )
    with gr.Row():
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Question",
                value="Who discovered penicillin?",
                lines=2,
                placeholder="Ask something Wikipedia-ish…",
            )
            topk = gr.Slider(1, 20, value=5, step=1, label="Top-K")
            examples = gr.Examples(
                examples=[
                    ["Who discovered penicillin?", 5],
                    ["What is the capital of France?", 5],
                    ["When was the Eiffel Tower completed?", 5],
                    ["Who wrote Pride and Prejudice?", 5],
                ],
                inputs=[question, topk],
                label="Examples",
            )
            run = gr.Button("🔎 Retrieve", variant="primary")
        with gr.Column(scale=1):
            results = gr.Textbox(
                label="Top passages",
                lines=24,  # larger output box
                show_copy_button=True,
            )
    run.click(retrieve, inputs=[question, topk], outputs=[results])
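# demo.launch() below uses Gradio defaults; when running locally, passing
# share=True would additionally create a temporary public URL.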
if __name__ == "__main__":
    demo.launch()