"""Gradio demo: dense passage retrieval (DPR) over a SQuAD subset via FAISS.

Loads a trained DPR bi-encoder checkpoint, embeds a small passage corpus,
builds an inner-product FAISS index over L2-normalized vectors (inner
product on unit vectors == cosine similarity), and serves an interactive
top-k retrieval UI.
"""

import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

import faiss
import gradio as gr
from datasets import load_dataset
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizerFast,
)

# -------------------
# Device setup
# -------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------
# Tokenizers (question and passage sides use different DPR tokenizers)
# -------------------
q_tok = DPRQuestionEncoderTokenizerFast.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base"
)
p_tok = DPRContextEncoderTokenizerFast.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base"
)


# -------------------
# DPR BiEncoder
# -------------------
class DPRBiEncoderHF(nn.Module):
    """Bi-encoder wrapping pretrained DPR question/context encoders.

    Optionally adds a linear projection head on each tower and carries a
    learnable temperature from training. Both ``encode_*`` methods return
    L2-normalized embeddings suitable for cosine-similarity search.

    Args:
        proj_dim: Output dimension of the optional projection heads.
        init_tau: Initial temperature; stored as ``log(1/tau)``.
        freeze_backbones: If True, disable gradients on both encoders.
        use_projection: If False, use identity heads and the backbone's
            hidden size as the output dimension.
    """

    def __init__(self, proj_dim=768, init_tau=0.07, freeze_backbones=False,
                 use_projection=True):
        super().__init__()
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        self.p_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base"
        )

        if freeze_backbones:
            for p in self.q_encoder.parameters():
                p.requires_grad = False
            for p in self.p_encoder.parameters():
                p.requires_grad = False

        self.use_projection = use_projection
        hidden = self.q_encoder.config.hidden_size  # 768 for the NQ base models

        if use_projection:
            self.q_proj = nn.Linear(hidden, proj_dim, bias=False)
            self.p_proj = nn.Linear(hidden, proj_dim, bias=False)
            out_dim = proj_dim
        else:
            self.q_proj = nn.Identity()
            self.p_proj = nn.Identity()
            out_dim = hidden

        # Learnable temperature (not used at retrieval time, but kept so the
        # trained checkpoint's state_dict loads without missing keys).
        self.logit_scale = nn.Parameter(
            torch.tensor(math.log(1.0 / init_tau), dtype=torch.float32)
        )
        self.out_dim = out_dim

    def _norm(self, x):
        """L2-normalize along the last dimension."""
        return F.normalize(x, dim=-1)

    @torch.no_grad()
    def encode_questions(self, **q_inputs):
        """Encode tokenized questions -> L2-normalized embeddings."""
        out = self.q_encoder(**q_inputs, return_dict=True)
        h = out.pooler_output
        z = self.q_proj(h)
        return self._norm(z)

    @torch.no_grad()
    def encode_passages(self, **p_inputs):
        """Encode tokenized passages -> L2-normalized embeddings."""
        out = self.p_encoder(**p_inputs, return_dict=True)
        h = out.pooler_output
        z = self.p_proj(h)
        return self._norm(z)


# -------------------
# Reload trained model
# -------------------
def load_model(weights_path="dpr_biencoder_hf.pt"):
    """Build the bi-encoder, load trained weights, and switch to eval mode.

    Args:
        weights_path: Path to a ``state_dict`` checkpoint saved from training.

    Returns:
        The model on ``device``, in eval mode.
    """
    model = DPRBiEncoderHF(
        proj_dim=768,
        init_tau=0.07,
        freeze_backbones=False,
        use_projection=True,
    ).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints. If this file is a plain state_dict, consider
    # weights_only=True — verify against how it was saved before changing.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model


model = load_model("dpr_biencoder_hf.pt")


# -------------------
# Dataset + batching helpers
# -------------------
def build_corpus(split="validation[:5%]"):
    """Build a deduplicated passage corpus from a SQuAD split.

    Args:
        split: A ``datasets`` split expression (e.g. ``"validation[:5%]"``).

    Returns:
        A list of ``{"id", "title", "text"}`` dicts; ids are positions in the
        deduplicated order.
    """
    ds = load_dataset("squad", split=split)
    # dict.fromkeys deduplicates while preserving first-seen order.
    uniq = list(dict.fromkeys(ds["context"]))
    return [{"id": i, "title": "", "text": ctx} for i, ctx in enumerate(uniq)]


def batch_p(titles, texts, max_len=256):
    """Tokenize passage (title, text) pairs and move tensors to ``device``."""
    enc = p_tok(
        text=titles,
        text_pair=texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    )
    return {k: v.to(device) for k, v in enc.items()}


def batch_q(questions, max_len=64):
    """Tokenize a list of questions and move tensors to ``device``."""
    enc = q_tok(
        questions,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    )
    return {k: v.to(device) for k, v in enc.items()}


@torch.no_grad()
def build_faiss_index(model, corpus, batch_size=64, p_max=256):
    """Embed the corpus in batches and index it with inner-product FAISS.

    Args:
        model: A ``DPRBiEncoderHF`` in eval mode.
        corpus: Output of :func:`build_corpus`.
        batch_size: Passages encoded per forward pass.
        p_max: Max passage token length.

    Returns:
        ``(index, corpus)`` — the populated ``faiss.IndexFlatIP`` and the
        corpus list (unchanged, returned for convenience).
    """
    all_vecs = []
    for i in range(0, len(corpus), batch_size):
        chunk = corpus[i:i + batch_size]
        titles = [c["title"] for c in chunk]
        texts = [c["text"] for c in chunk]
        zp = model.encode_passages(**batch_p(titles, texts, p_max))
        all_vecs.append(zp.cpu().numpy().astype("float32"))

    doc_embs = np.vstack(all_vecs)
    dim = doc_embs.shape[1]
    index = faiss.IndexFlatIP(dim)  # inner product == cosine on unit vectors
    index.add(doc_embs)
    return index, corpus


# Build a small validation index on launch (same as the original demo).
corpus = build_corpus("validation[:5%]")
index, corpus = build_faiss_index(model, corpus)


# -------------------
# Retrieval
# -------------------
@torch.no_grad()
def encode_question(model, question, q_max=64):
    """Encode a single question into a (1, dim) float32 numpy array."""
    zq = model.encode_questions(**batch_q([question], q_max))
    return zq.cpu().numpy().astype("float32")


def retrieve(question, k=5):
    """Return the top-k passages for ``question`` as a formatted string.

    Args:
        question: Free-text query.
        k: Number of passages to return. Gradio sliders may deliver a float,
            so the value is cast to int before the FAISS call.

    Returns:
        A human-readable, rank-ordered listing of scores, ids, and snippets.
    """
    k = int(k)  # FAISS requires an integer k; the slider value may be a float.
    q = encode_question(model, question)
    scores, idxs = index.search(q, k)
    idxs, scores = idxs[0], scores[0]

    # Pretty formatting for a bigger textbox.
    lines = []
    for rank, (row, score) in enumerate(zip(idxs, scores), 1):
        if row < 0:
            # FAISS pads with -1 when fewer than k results exist; without this
            # guard, corpus[-1] would silently show the last passage.
            continue
        doc = corpus[int(row)]
        snippet = doc["text"].replace("\n", " ")
        if len(snippet) > 700:
            snippet = snippet[:700] + "…"
        lines.append(f"{rank}. score={score:.3f} | id={doc['id']}\n   {snippet}")
    return "\n\n".join(lines)


# -------------------
# Gradio UI (bigger output + light polish)
# -------------------
css = """
.gradio-container {max-width: 1200px !important;}
.footer {visibility: hidden;}
"""

with gr.Blocks(title="DPR Retriever Demo", css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "## DPR Retriever Demo\n"
        "Dense retrieval over a small subset of **SQuAD** (validation 5%). "
        "Enter a question and get the top-k passages ranked by cosine similarity."
    )
    with gr.Row():
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Question",
                value="Who discovered penicillin?",
                lines=2,
                placeholder="Ask something Wikipedia-ish…",
            )
            topk = gr.Slider(1, 20, value=5, step=1, label="Top-K")
            examples = gr.Examples(
                examples=[
                    ["Who discovered penicillin?", 5],
                    ["What is the capital of France?", 5],
                    ["When was the Eiffel Tower completed?", 5],
                    ["Who wrote Pride and Prejudice?", 5],
                ],
                inputs=[question, topk],
                label="Examples",
            )
            run = gr.Button("🔎 Retrieve", variant="primary")
        with gr.Column(scale=1):
            results = gr.Textbox(
                label="Top passages",
                lines=24,  # <<< bigger output box
                show_copy_button=True,
            )
    run.click(retrieve, inputs=[question, topk], outputs=[results])

if __name__ == "__main__":
    demo.launch()