# MUVERA / app.py
# redhairedshanks1's picture
# Update app.py
# 03ba8f7 verified
# import gradio as gr
# import numpy as np
# import docx
# from typing import List, Tuple, Dict
# # ------------------------------
# # Embedding: Real SentenceTransformer (preferred), fallback to dummy
# # ------------------------------
# try:
# from sentence_transformers import SentenceTransformer
# _embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
# EMBEDDING_DIM = _embedding_model.get_sentence_embedding_dimension()
# def embed_text(text: str) -> np.ndarray:
# return _embedding_model.encode(text, normalize_embeddings=True)
# USING_REAL = True
# except Exception:
# EMBEDDING_DIM = 32
# def embed_text(text: str) -> np.ndarray:
# np.random.seed(abs(hash(text)) % (10**6))
# return np.random.randn(EMBEDDING_DIM)
# USING_REAL = False
# # ------------------------------
# # Core Classes with Snippet Support
# # ------------------------------
# class MultiVectorDocument:
# def __init__(self, doc_id: str, vectors: List[np.ndarray], texts: List[str], metadata: Dict = None):
# self.doc_id = doc_id
# self.vectors = vectors # list of embeddings
# self.texts = texts # original paragraphs/chunks
# self.metadata = metadata or {}
# class SingleVectorIndex:
# def __init__(self, dim: int):
# self.dim = dim
# self.docs = {} # doc_id β†’ vector
# self.texts = {} # doc_id β†’ snippet preview
# def add_document(self, doc: MultiVectorDocument):
# centroid = np.mean(doc.vectors, axis=0)
# self.docs[doc.doc_id] = centroid / np.linalg.norm(centroid)
# # preview: first couple of paragraphs
# self.texts[doc.doc_id] = " | ".join(doc.texts[:2])
# def search(self, query_vec: np.ndarray, top_k=3):
# qn = query_vec / np.linalg.norm(query_vec)
# scores = [(doc_id,
# self.texts[doc_id],
# float(np.dot(qn, vec)))
# for doc_id, vec in self.docs.items()]
# return sorted(scores, key=lambda x: -x[2])[:top_k]
# class MuVERAIndex:
# def __init__(self, dim: int):
# self.dim = dim
# self.corpus = {}
# self.global_centroids = {}
# def add_document(self, doc: MultiVectorDocument):
# self.corpus[doc.doc_id] = doc
# centroid = np.mean(doc.vectors, axis=0)
# self.global_centroids[doc.doc_id] = centroid / np.linalg.norm(centroid)
# def search(self, query_vec: np.ndarray, top_k: int = 3):
# qn = query_vec / np.linalg.norm(query_vec)
# # Step 1: shortlist by centroid
# scores = [(doc_id, float(np.dot(qn, cent)))
# for doc_id, cent in self.global_centroids.items()]
# shortlist = sorted(scores, key=lambda x: -x[1])[: top_k * 2]
# # Step 2: fine-grained on passages
# reranked = []
# for doc_id, _ in shortlist:
# doc = self.corpus[doc_id]
# sims = [np.dot(qn, v/np.linalg.norm(v)) for v in doc.vectors]
# best_idx = int(np.argmax(sims))
# reranked.append((doc_id, doc.texts[best_idx], float(sims[best_idx])))
# return sorted(reranked, key=lambda x: -x[2])[:top_k]
# # ------------------------------
# # File Loaders (docx, txt)
# # ------------------------------
# def load_docx(path: str):
# doc = docx.Document(path)
# texts, vectors = [], []
# for para in doc.paragraphs:
# if para.text.strip():
# texts.append(para.text.strip())
# vectors.append(embed_text(para.text.strip()))
# return MultiVectorDocument(doc_id=path.split("/")[-1], vectors=vectors, texts=texts)
# def load_txt(path: str):
# with open(path, "r", encoding="utf-8") as f:
# lines = [line.strip() for line in f if line.strip()]
# vectors = [embed_text(line) for line in lines]
# return MultiVectorDocument(doc_id=path.split("/")[-1], vectors=vectors, texts=lines)
# # ------------------------------
# # App Initialization
# # ------------------------------
# dim = EMBEDDING_DIM
# single_index = SingleVectorIndex(dim)
# muvera_index = MuVERAIndex(dim)
# def add_files(files):
# added = []
# for f in files:
# if f.name.endswith(".docx"):
# doc = load_docx(f.name)
# elif f.name.endswith(".txt"):
# doc = load_txt(f.name)
# else:
# continue
# single_index.add_document(doc)
# muvera_index.add_document(doc)
# added.append(doc.doc_id)
# return f"βœ… Indexed: {', '.join(added)}" if added else "⚠️ No valid docs uploaded."
# def query(q: str, top_k: int = 3):
# if not q.strip():
# return "Please enter a query", "Please enter a query"
# q_vec = embed_text(q)
# single_results = single_index.search(q_vec, top_k)
# muvera_results = muvera_index.search(q_vec, top_k)
# def fmt(results):
# if not results:
# return "No results yet. Upload docs first."
# return "\n\n".join([
# f"{rank+1}. πŸ“„ {doc_id}\n ✨ Snippet: {snippet}\n πŸ”Ή Score={score:.3f}"
# for rank, (doc_id, snippet, score) in enumerate(results)
# ])
# return fmt(single_results), fmt(muvera_results)
# # ------------------------------
# # Gradio Interface
# # ------------------------------
# with gr.Blocks() as demo:
# gr.Markdown("## πŸ”Ž MuVERA Demo: Multi-Vector Retrieval vs Single Vector Search")
# gr.Markdown("Upload `.docx` or `.txt` files (small text docs), then compare retrieval methods.")
# with gr.Row():
# uploader = gr.File(file_types=[".docx", ".txt"], file_count="multiple")
# status = gr.Textbox(label="Index status")
# uploader.upload(add_files, uploader, status)
# q_box = gr.Textbox(label="Enter query", placeholder="Search something like: efficient retrieval methods...")
# topk_slider = gr.Slider(1, 5, value=3, step=1, label="Top-k Results")
# with gr.Row():
# out_single = gr.Textbox(label="Single-Vector Results", lines=12)
# out_muvera = gr.Textbox(label="MuVERA Results", lines=12)
# btn = gr.Button("Search πŸ”")
# btn.click(query, [q_box, topk_slider], [out_single, out_muvera])
# demo.launch()
import os
import zlib
from typing import List, Tuple, Dict

import docx
import gradio as gr
import numpy as np
# ------------------------------
# Embedding: Real SentenceTransformer (preferred), fallback to dummy
# ------------------------------
try:
    from sentence_transformers import SentenceTransformer

    _embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    EMBEDDING_DIM = _embedding_model.get_sentence_embedding_dimension()

    def embed_text(text: str) -> np.ndarray:
        """Embed *text* with the real model; vectors come back L2-normalized."""
        return _embedding_model.encode(text, normalize_embeddings=True)

    USING_REAL = True
except Exception:
    # Fallback: deterministic pseudo-random vectors so the demo still works
    # without sentence-transformers installed.
    EMBEDDING_DIM = 32

    def embed_text(text: str) -> np.ndarray:
        """Return a stable pseudo-embedding for *text*.

        Seeds from CRC32 instead of the builtin hash(): str hashing is
        randomized per process (PYTHONHASHSEED), so hash() would produce
        different vectors for the same text across runs.  The vector is
        L2-normalized to match the real model's output convention.
        """
        seed = zlib.crc32(text.encode("utf-8")) % (10**6)
        rng = np.random.default_rng(seed)
        vec = rng.standard_normal(EMBEDDING_DIM)
        return vec / np.linalg.norm(vec)

    USING_REAL = False
# ------------------------------
# Core Classes
# ------------------------------
class MultiVectorDocument:
    """A document represented by several passage embeddings.

    Attributes:
        doc_id: identifier shown in search results.
        vectors: one embedding per passage.
        texts: the passages themselves, aligned index-for-index with vectors.
        metadata: optional free-form extra information (empty dict if falsy).
    """

    def __init__(self, doc_id: str, vectors: List[np.ndarray], texts: List[str], metadata: Dict = None):
        self.doc_id = doc_id
        self.vectors = vectors
        self.texts = texts
        # Any falsy metadata (None, {}) is replaced by a fresh empty dict.
        self.metadata = metadata if metadata else {}
class SingleVectorIndex:
    """Baseline index: every document is collapsed into one centroid vector."""

    def __init__(self, dim: int):
        self.dim = dim
        self.docs = {}   # doc_id -> unit-norm centroid vector
        self.texts = {}  # doc_id -> short snippet preview

    def add_document(self, doc: MultiVectorDocument):
        """Collapse *doc* into one normalized centroid and keep a text preview."""
        centroid = np.mean(doc.vectors, axis=0)
        self.docs[doc.doc_id] = centroid / np.linalg.norm(centroid)
        # Preview is built from the first two passages only.
        self.texts[doc.doc_id] = " | ".join(doc.texts[:2])

    def search(self, query_vec: np.ndarray, top_k=3):
        """Return up to *top_k* (doc_id, preview, cosine_score) tuples, best first."""
        unit_query = query_vec / np.linalg.norm(query_vec)
        ranked = []
        for doc_id, centroid in self.docs.items():
            similarity = float(np.dot(unit_query, centroid))
            ranked.append((doc_id, self.texts[doc_id], similarity))
        ranked.sort(key=lambda item: item[2], reverse=True)
        return ranked[:top_k]
class MuVERAIndex:
    """Two-stage multi-vector index.

    Stage 1 shortlists documents by centroid similarity; stage 2 scores every
    passage of the shortlisted documents; stage 3 returns the globally best
    snippets across documents.
    """

    def __init__(self, dim: int):
        self.dim = dim
        self.corpus = {}            # doc_id -> full multi-vector document
        self.global_centroids = {}  # doc_id -> unit-norm centroid

    def add_document(self, doc: MultiVectorDocument):
        """Keep the full multi-vector doc and cache its normalized centroid."""
        self.corpus[doc.doc_id] = doc
        centroid = np.mean(doc.vectors, axis=0)
        self.global_centroids[doc.doc_id] = centroid / np.linalg.norm(centroid)

    def search(self, query_vec: np.ndarray, top_k=3, per_doc_hits=2):
        """Return up to ``top_k * per_doc_hits`` (doc_id, passage, score) tuples."""
        unit_query = query_vec / np.linalg.norm(query_vec)
        # Stage 1: shortlist 3*top_k docs by centroid cosine similarity.
        by_centroid = sorted(
            ((doc_id, float(np.dot(unit_query, centroid)))
             for doc_id, centroid in self.global_centroids.items()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        shortlist = by_centroid[: top_k * 3]
        # Stage 2: score every passage of every shortlisted document.
        scored = []
        for doc_id, _ in shortlist:
            doc = self.corpus[doc_id]
            for passage, vec in zip(doc.texts, doc.vectors):
                similarity = float(np.dot(unit_query, vec / np.linalg.norm(vec)))
                scored.append((doc_id, passage, similarity))
        # Stage 3: globally best snippets across all shortlisted docs.
        scored.sort(key=lambda triple: triple[2], reverse=True)
        return scored[: top_k * per_doc_hits]
# ------------------------------
# File Loaders
# ------------------------------
def load_docx(path: str):
    """Load a .docx file into a MultiVectorDocument (one vector per paragraph).

    Empty/whitespace-only paragraphs are skipped.  The doc_id is the file's
    base name, computed with os.path.basename so it also works with Windows
    backslash paths (the original split("/") only handled forward slashes).
    """
    document = docx.Document(path)
    texts, vectors = [], []
    for para in document.paragraphs:
        text = para.text.strip()
        if text:
            texts.append(text)
            vectors.append(embed_text(text))
    return MultiVectorDocument(doc_id=os.path.basename(path), vectors=vectors, texts=texts)
def load_txt(path: str):
    """Load a UTF-8 text file into a MultiVectorDocument (one vector per line).

    Blank lines are skipped.  The doc_id is the file's base name, computed
    with os.path.basename so Windows backslash paths also work (the original
    split("/") only handled forward slashes).
    """
    with open(path, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]
    vectors = [embed_text(line) for line in lines]
    return MultiVectorDocument(doc_id=os.path.basename(path), vectors=vectors, texts=lines)
# ------------------------------
# App State
# ------------------------------
# Module-level singletons: both indexes share the embedding dimension chosen
# above (real model or 32-dim fallback) and are populated in place by
# add_files() below.
dim = EMBEDDING_DIM
single_index = SingleVectorIndex(dim)
muvera_index = MuVERAIndex(dim)
# ------------------------------
# Functions for Gradio
# ------------------------------
def add_files(files):
    """Index uploaded files into both indexes.

    Args:
        files: iterable of Gradio file objects; each exposes a ``.name``
            attribute holding the temp-file path.

    Returns:
        A status string listing indexed doc_ids, or a warning when nothing
        uploadable was found.
    """
    added = []
    for f in files:
        # Compare against a lower-cased path so ".DOCX"/".TXT" uploads are
        # accepted too (the original check was case-sensitive and silently
        # skipped upper-case extensions).
        lowered = f.name.lower()
        if lowered.endswith(".docx"):
            doc = load_docx(f.name)
        elif lowered.endswith(".txt"):
            doc = load_txt(f.name)
        else:
            continue  # unsupported file type
        single_index.add_document(doc)
        muvera_index.add_document(doc)
        added.append(doc.doc_id)
    return f"βœ… Indexed: {', '.join(added)}" if added else "⚠️ No valid docs uploaded."
def query(q: str, top_k: int = 3):
    """Run the query against both indexes and return two display strings.

    Args:
        q: free-text query; a blank/whitespace query short-circuits with a
            prompt message in both outputs.
        top_k: number of results to show.  Gradio sliders can deliver floats,
            so the value is coerced to int before slicing.

    Returns:
        (single_vector_results, muvera_results) as formatted strings.
    """
    if not q.strip():
        return "Please enter a query", "Please enter a query"
    top_k = int(top_k)
    q_vec = embed_text(q)
    single_results = single_index.search(q_vec, top_k)
    muvera_results = muvera_index.search(q_vec, top_k, per_doc_hits=2)

    def fmt(results):
        # Both result lists share the (doc_id, snippet, score) shape, so one
        # formatter suffices (the unused "mode" flag was dropped).
        if not results:
            return "No results yet. Upload docs first."
        return "\n\n".join(
            f"{rank+1}. πŸ“„ {doc_id}\n ✨ Snippet: {snippet}\n πŸ”Ή Score={score:.3f}"
            for rank, (doc_id, snippet, score) in enumerate(results)
        )

    return fmt(single_results), fmt(muvera_results)
# ------------------------------
# Gradio UI
# ------------------------------
# Layout: an upload row that feeds the indexes, a query box plus top-k
# slider, and two side-by-side output panes so the two retrieval methods can
# be compared on the same query.
with gr.Blocks() as demo:
    gr.Markdown("## πŸ”Ž MuVERA Demo: Multi-Vector Retrieval vs Single-Vector Search")
    gr.Markdown("Upload `.docx` or `.txt` files, then compare retrieval systems.")
    with gr.Row():
        uploader = gr.File(file_types=[".docx", ".txt"], file_count="multiple")
        status = gr.Textbox(label="Index status")
    # Indexing is triggered immediately on upload; status reports doc_ids.
    uploader.upload(add_files, uploader, status)
    q_box = gr.Textbox(label="Query", placeholder="Example: efficient retrieval methods")
    topk_slider = gr.Slider(1, 5, value=3, step=1, label="Top-k Docs to Consider")
    with gr.Row():
        out_single = gr.Textbox(label="Single-Vector Results (Doc-level)", lines=10)
        out_muvera = gr.Textbox(label="MuVERA Results (Top Snippets)", lines=15)
    btn = gr.Button("Search πŸ”")
    # Search runs only on button click (not live typing).
    btn.click(query, [q_box, topk_slider], [out_single, out_muvera])
# NOTE(review): launch() blocks at import; fine for a Space's app.py entry point.
demo.launch()