import time

import duckdb
import faiss
import numpy as np
import streamlit as st
from gensim.models import Word2Vec
from huggingface_hub import hf_hub_download
from pylatexenc.latex2text import LatexNodes2Text
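# ----------------------------------
# Cached Resources & Helpers
# ----------------------------------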
@st.cache_resource
def get_db(path='arxiv.db'):
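    """Open a DuckDB connection to the arXiv metadata database (cached by Streamlit)."""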
return duckdb.connect(path)
@st.cache_resource
def get_fast_lookup(_model):
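    """Precompute the raw embedding matrix and a word -> row-index map for O(1) lookups.

    The leading underscore in `_model` tells st.cache_resource not to hash the argument.
    """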
    vectors = _model.wv.vectors  # dense (vocab_size, dim) float32 embedding matrix
word_to_index = {word: idx for idx, word in enumerate(_model.wv.index_to_key)}
return vectors, word_to_index
@st.cache_resource
def load_arxiv_dict():
    """Load all paper metadata into memory, keyed by the integer row id (column0)."""
    con = get_db()  # reuse the cached connection instead of opening a second one
    df = con.execute("""
        SELECT column0, id, title, abstract, categories
        FROM arxiv
    """).fetchdf()
    # dictionary: column0 -> row; itertuples is much faster than iterrows here
    return {
        int(row.column0): {
            "id": row.id,
            "title": row.title,
            "abstract": row.abstract,
            "categories": row.categories,
        }
        for row in df.itertuples(index=False)
    }
def query_neighbours(rows):
    """Map FAISS neighbour ids to paper metadata, skipping misses (FAISS pads with -1)."""
    return [arxiv_dict.get(int(x)) for x in rows if int(x) in arxiv_dict]
@st.cache_resource
def get_model():
    """Download the Word2Vec model plus its sidecar .npy arrays, then load it."""
    model_path = hf_hub_download(
        repo_id="nullHawk/word2vec-skipgram-arxive",
        filename="word2vec_arxiv_skipgram.model"
    )
    # Gensim stores large arrays in sidecar .npy files next to the .model file;
    # the return values are unused, but the downloads are required so that
    # Word2Vec.load can find them in the same cache snapshot.
    hf_hub_download(
        repo_id="nullHawk/word2vec-skipgram-arxive",
        filename="word2vec_arxiv_skipgram.model.syn1neg.npy"
    )
    hf_hub_download(
        repo_id="nullHawk/word2vec-skipgram-arxive",
        filename="word2vec_arxiv_skipgram.model.wv.vectors.npy"
    )
    return Word2Vec.load(model_path)
@st.cache_resource
def get_faiss_index():
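    """Load the prebuilt FAISS index from disk (cached across reruns)."""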
return faiss.read_index("bin/faiss_search_index.bin")
def run_semantic_search(query, top_k):
    """Embed the query as the mean of its word vectors, then search the FAISS index."""
    index = faiss_index
    words = query.lower().split()
    vecs = []
    start_t = time.time()
    # Collect embeddings for every query word present in the vocabulary
    for w in words:
        idx = word_to_index.get(w)
        if idx is not None:
            vecs.append(vectors[idx])
    mid_t = time.time()
    print(f"Embedding lookup time: {mid_t - start_t}")
if not vecs:
return []
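    # Mean-pool the word vectors into one query vector and L2-normalise it, so
    # inner-product search behaves like cosine similarity (assuming the index
    # was built from normalised vectors)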
qvec = np.mean(vecs, axis=0).astype('float32').reshape(1, -1)
faiss.normalize_L2(qvec)
scores, neighbors = index.search(qvec, top_k)
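    # neighbors[0] holds the row ids of the top_k matches; scores their similarities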
mid2_t = time.time()
print(f"Search time : {mid2_t - mid_t}")
result = query_neighbours(neighbors[0])
print(f"Query time : {time.time() - mid2_t}\n\n\n")
return result
# ----------------------------------
# Global Variables
# ----------------------------------
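# Load heavyweight resources once; st.cache_resource shares them across reruns and sessions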
model = get_model()
faiss_index = get_faiss_index()
db = get_db()
vectors, word_to_index = get_fast_lookup(model)
arxiv_dict = load_arxiv_dict()
# ----------------------------------
# Streamlit Page Setup
# ----------------------------------
st.set_page_config(page_title="ArXiv Semantic Search", layout="wide")
st.title("ArXiv Semantic Search Engine")
st.write("Search over millions of research papers using semantic similarity.")
# Sidebar
st.sidebar.header("Search Options")
top_k = st.sidebar.slider("Top K Results", 5, 50, 10)
# Main Search Bar
query = st.text_input(
"Enter your search query:",
placeholder="e.g. diffusion models for text-to-image, graph neural networks, LLM alignment..."
)
search_button = st.button("Search")
# --------------------------------------------------------------
# Handle search click
# --------------------------------------------------------------
if search_button and query.strip():
start_time = time.time()
with st.spinner("Searching..."):
results = run_semantic_search(query, top_k)
end_time = time.time()
elapsed = end_time - start_time
st.write(f"**Your query took {elapsed:.3f} seconds**")
    if results:
st.header(f"Top {top_k} Results")
# ----------------------------------------------------------
# Display results (card-style)
# ----------------------------------------------------------
        latex = LatexNodes2Text()  # one converter reused for every field
        for i, paper in enumerate(results, start=1):
            title = latex.latex_to_text(paper["title"].replace("\n", " ").strip())
            abstract = latex.latex_to_text(paper["abstract"][:600])
            st.markdown(f"### **[{i}. {title}](https://arxiv.org/abs/{paper['id']})**")
            st.markdown(f"**Categories:** {paper['categories']}")
            st.markdown(f"**Abstract:** {abstract}...")
            st.markdown("---")
else:
        st.markdown("No results: none of the query words are in the model's vocabulary.")