File size: 4,999 Bytes
56689a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import gradio as gr
import os
import json
from pathlib import Path
from typing import List, Dict
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
# --- Configuration ---
# Directory containing the prebuilt FAISS index ("faiss.index") and its
# metadata sidecar ("meta.jsonl"), both produced by pdf_semsearch.py.
INDEX_DIR = "./index_dir"
# Bi-encoder used to embed queries; e5-family models require role prefixes
# (see e5_prefix below).
EMBED_MODEL = "intfloat/e5-base-v2"
# Cross-encoder used to rerank the ANN candidates.
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# Number of approximate-nearest-neighbour candidates fetched from FAISS.
FETCH_K = 40
# Number of reranked results actually shown to the user.
TOP_K = 8
# --- Global variables to hold models and data ---
# Populated once by load_models_and_data() and read by search().
index = None
meta = None
embedder = None
reranker = None
# --- Utility Functions (adapted from pdf_semsearch.py) ---
def e5_prefix(text: str, is_query: bool, model_name: str) -> str:
    """Prefix *text* with the role marker expected by E5-family models.

    E5 embedding models are trained with "query: " / "passage: " prefixes;
    for any other model the text is returned unchanged.
    """
    if "e5" not in model_name.lower():
        return text
    role = "query" if is_query else "passage"
    return f"{role}: {text}"
def read_metadata(meta_path: str) -> List[Dict]:
    """Read per-chunk metadata records from a JSONL file.

    Args:
        meta_path: Path to a file containing one JSON object per line.

    Returns:
        A list of dicts, one per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records: List[Dict] = []
    with open(meta_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Skip blank lines (e.g. a trailing newline at EOF), which
            # would otherwise raise a JSONDecodeError in json.loads.
            if line:
                records.append(json.loads(line))
    return records
def pretty_snippet(s: str, max_len: int = 320) -> str:
    """Collapse whitespace runs to single spaces and cap the length.

    Strings longer than *max_len* are truncated to max_len - 1 characters
    and terminated with a '…' ellipsis.
    """
    collapsed = " ".join(s.split())
    if len(collapsed) <= max_len:
        return collapsed
    return collapsed[: max_len - 1] + "…"
# --- Model and Data Loading ---
def load_models_and_data():
    """Load the FAISS index, chunk metadata, and both models into module globals.

    Populates ``index``, ``meta``, ``embedder`` and ``reranker`` so that
    search() can use them.

    Raises:
        FileNotFoundError: If the index files have not been built yet.
    """
    global index, meta, embedder, reranker

    faiss_path = os.path.join(INDEX_DIR, "faiss.index")
    metadata_path = os.path.join(INDEX_DIR, "meta.jsonl")
    # Both files are required; bail out with guidance if either is missing.
    if not (os.path.exists(faiss_path) and os.path.exists(metadata_path)):
        raise FileNotFoundError(
            f"Index not found in '{INDEX_DIR}'. "
            "Please run the indexing command from pdf_semsearch.py first."
        )

    print(f"[*] Loading FAISS index: {faiss_path}")
    index = faiss.read_index(faiss_path)

    print("[*] Loading metadata…")
    meta = read_metadata(metadata_path)

    print(f"[*] Loading embedding model: {EMBED_MODEL}")
    embedder = SentenceTransformer(EMBED_MODEL)

    print(f"[*] Loading reranker model: {RERANKER_MODEL}")
    reranker = CrossEncoder(RERANKER_MODEL)

    print("[✓] Models and data loaded.")
# --- Search Function ---
def search(query: str):
    """Run semantic search over the loaded index and return Markdown results.

    Pipeline: embed the query, fetch FETCH_K approximate neighbours from
    FAISS, rerank them with the cross-encoder, and format the TOP_K best.

    Args:
        query: Free-text user query.

    Returns:
        A Markdown string with the ranked results, or a human-readable
        message when the query is empty, the app is not initialised, or
        nothing matched.
    """
    if not query or not query.strip():
        return "Please enter a search query."
    # Explicit `is None` checks: truthiness is wrong here — an empty
    # metadata list (or an index type without a meaningful __bool__)
    # would be misreported as "not loaded".
    if index is None or meta is None or embedder is None or reranker is None:
        return "Error: Models or data not loaded. Please check the console."

    # 1. Embed the query (E5 models need the "query: " role prefix).
    query_text = e5_prefix(query, is_query=True, model_name=EMBED_MODEL)
    qvec = embedder.encode([query_text], normalize_embeddings=True).astype("float32")

    # 2. ANN search over the FAISS index.
    ann_scores, ann_ids = index.search(qvec, FETCH_K)

    # 3. Collect candidate chunks; FAISS pads with -1 ids when fewer
    # than FETCH_K vectors exist in the index.
    candidates = []
    for rank, idx in enumerate(ann_ids[0]):
        if idx == -1:
            continue
        rec = dict(meta[idx])
        rec["ann_score"] = float(ann_scores[0][rank])
        candidates.append(rec)
    if not candidates:
        return "No results found."

    # 4. Rerank candidates with the cross-encoder on (query, passage) pairs.
    pairs = [(query, c["text"]) for c in candidates]
    rerank_scores = reranker.predict(pairs)
    for c, s in zip(candidates, rerank_scores):
        c["rerank_score"] = float(s)
    candidates.sort(key=lambda x: x["rerank_score"], reverse=True)

    # 5. Format the top results for display.
    output = f"## Results for: \"{query}\"\n\n"
    for i, r in enumerate(candidates[:TOP_K], start=1):
        base = Path(r["doc_path"]).name
        score = r.get("rerank_score", r["ann_score"])
        output += (
            f"**{i}. {base} (Page: {r['page']}, Score: {score:.3f})**\n\n"
            f"> {pretty_snippet(r['text'])}\n\n"
            # Trailing blank line is required: without it, Markdown glues
            # "---" onto the next entry's bold line ("---**2. …"), breaking
            # both the horizontal rule and the heading (original bug).
            "---\n\n"
        )
    return output
# --- Gradio App ---
def create_gradio_app():
    """Build and return the Gradio Blocks UI for the search app."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # Semantic PDF Search
            Enter a query to search through the indexed PDF documents.
            The index must be created first using `pdf_semsearch.py`.
            """
        )
        with gr.Row():
            query_box = gr.Textbox(
                label="Search Query",
                placeholder="e.g., KRAS G12C eligibility in lung cancer",
                lines=1,
                scale=4,
            )
            run_button = gr.Button("Search", variant="primary", scale=1)
        results_panel = gr.Markdown(label="Search Results")
        # Wire the same handler to both the button click and pressing
        # Enter in the textbox.
        for trigger in (run_button.click, query_box.submit):
            trigger(
                fn=search,
                inputs=query_box,
                outputs=results_panel,
            )
    return demo
if __name__ == "__main__":
    # Load models and data up front so the first query does not pay the
    # model-loading cost, then start the Gradio server.
    load_models_and_data()
    create_gradio_app().launch()
|