from huggingface_hub import hf_hub_download
from gensim.models import Word2Vec
from nltk import word_tokenize, sent_tokenize
from pylatexenc.latex2text import LatexNodes2Text
import faiss
import duckdb
import time
import streamlit as st
import numpy as np
import pandas as pd
import dask.dataframe as dd
@st.cache_resource  # assumed: reuse one DuckDB connection across Streamlit reruns
def get_db(path='arxiv.db'):
    return duckdb.connect(path)
@st.cache_resource  # assumed: the leading underscore on _model is Streamlit's "don't hash this argument" convention
def get_fast_lookup(_model):
    vectors = _model.wv.vectors  # dense NumPy matrix of all word vectors (fast row access)
    word_to_index = {word: idx for idx, word in enumerate(_model.wv.index_to_key)}
    return vectors, word_to_index
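# Note: word_to_index duplicates what gensim 4.x already exposes as
# model.wv.key_to_index; vectors[word_to_index[w]] returns the same row as
# model.wv[w], but skips gensim's per-call overhead in the hot search loop.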
@st.cache_resource  # assumed: build the id -> metadata dict once, not on every rerun
def load_arxiv_dict():
    con = duckdb.connect("arxiv.db")
    df = con.execute("""
        SELECT column0, id, title, abstract, categories
        FROM arxiv
    """).fetchdf()
    # dictionary: column0 (positional id, matching FAISS row order) -> row
    return {
        int(row["column0"]): {
            "id": row["id"],
            "title": row["title"],
            "abstract": row["abstract"],
            "categories": row["categories"],
        }
        for _, row in df.iterrows()
    }
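# Offline sketch (an assumption, not the author's ingestion script): a table
# with the columns read above could be built from the public arXiv metadata
# dump with DuckDB's JSON reader. The file name and column layout here are
# assumed; this helper is never called by the app.
def build_arxiv_db(json_path="arxiv-metadata-oai-snapshot.json", db_path="arxiv.db"):
    con = duckdb.connect(db_path)
    con.execute(f"""
        CREATE TABLE IF NOT EXISTS arxiv AS
        SELECT row_number() OVER () - 1 AS column0,  -- 0-based positional id
               id, title, abstract, categories
        FROM read_json_auto('{json_path}')
    """)
    con.close()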
def query_neighbours(rows):
    """Map FAISS neighbour ids back to paper metadata, skipping unknown ids."""
    global arxiv_dict
    return [arxiv_dict[i] for i in map(int, rows) if i in arxiv_dict]
@st.cache_resource  # assumed: download and load the model once per process
def get_model():
    model_path = hf_hub_download(
        repo_id="nullHawk/word2vec-skipgram-arxive",
        filename="word2vec_arxiv_skipgram.model"
    )
    # gensim stores large arrays in .npy sidecar files next to the main file;
    # download them into the same cache directory so Word2Vec.load() can find them.
    hf_hub_download(
        repo_id="nullHawk/word2vec-skipgram-arxive",
        filename="word2vec_arxiv_skipgram.model.syn1neg.npy"
    )
    hf_hub_download(
        repo_id="nullHawk/word2vec-skipgram-arxive",
        filename="word2vec_arxiv_skipgram.model.wv.vectors.npy"
    )
    return Word2Vec.load(model_path)
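# Offline sketch (assumed, not the published training script): a skip-gram
# model like the one on the Hub could be produced along these lines, which
# would also explain the .npy sidecar files and the nltk import above:
#   sentences = [word_tokenize(a.lower()) for a in abstracts]   # nltk tokenizer
#   m = Word2Vec(sentences, vector_size=300, sg=1, workers=4)   # sg=1 => skip-gram
#   m.save("word2vec_arxiv_skipgram.model")                     # writes the sidecars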
@st.cache_resource  # assumed: read the prebuilt index from disk once
def get_faiss_index():
    return faiss.read_index("bin/faiss_search_index.bin")
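# Offline sketch (an assumption, not the author's build script). Given how
# run_semantic_search embeds a query below (mean of word vectors, L2-normalized,
# inner-product search), the index was presumably built from per-paper vectors
# prepared the same way, with row i matching column0 == i in arxiv.db:
def build_faiss_index(paper_vectors, out_path="bin/faiss_search_index.bin"):
    mat = np.asarray(paper_vectors, dtype="float32")
    faiss.normalize_L2(mat)                  # cosine similarity via inner product
    index = faiss.IndexFlatIP(mat.shape[1])  # exact inner-product index
    index.add(mat)
    faiss.write_index(index, out_path)
    return index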
def run_semantic_search(query, top_k):
    global model, faiss_index, word_to_index, vectors
    index = faiss_index
    words = query.lower().split()
    vecs = []
    start_t = time.time()
    for w in words:
        idx = word_to_index.get(w)
        if idx is not None:  # skip out-of-vocabulary words
            vecs.append(vectors[idx])
    mid_t = time.time()
    print(f"Vector lookup time: {mid_t - start_t}")
    if not vecs:  # no query word is in the vocabulary
        return []
    # Query embedding: mean of word vectors, L2-normalized so the
    # inner-product search behaves like cosine similarity.
    qvec = np.mean(vecs, axis=0).astype('float32').reshape(1, -1)
    faiss.normalize_L2(qvec)
    scores, neighbors = index.search(qvec, top_k)
    mid2_t = time.time()
    print(f"Search time: {mid2_t - mid_t}")
    result = query_neighbours(neighbors[0])
    print(f"Query time: {time.time() - mid2_t}")
    return result
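# Quick REPL sanity check (assumed usage; requires the globals defined below):
#   hits = run_semantic_search("graph neural networks", 5)
#   all({"id", "title", "abstract", "categories"} <= set(h) for h in hits)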
# -----------------------------------
# Global Variables
# -----------------------------------
model = get_model()
faiss_index = get_faiss_index()
db = get_db()
vectors, word_to_index = get_fast_lookup(model)
arxiv_dict = load_arxiv_dict()
# ----------------------------------
# Streamlit Page Setup
# ----------------------------------
st.set_page_config(page_title="ArXiv Semantic Search", layout="wide")
st.title("ArXiv Semantic Search Engine")
st.write("Search over millions of research papers using semantic similarity.")

# Sidebar
st.sidebar.header("Search Options")
top_k = st.sidebar.slider("Top K Results", 5, 50, 10)

# Main Search Bar
query = st.text_input(
    "Enter your search query:",
    placeholder="e.g. diffusion models for text-to-image, graph neural networks, LLM alignment..."
)
search_button = st.button("Search")
# --------------------------------------------------------------
# Handle search click
# --------------------------------------------------------------
if search_button and query.strip():
    start_time = time.time()
    with st.spinner("Searching..."):
        results = run_semantic_search(query, top_k)
    end_time = time.time()
    elapsed = end_time - start_time

    st.write(f"**Your query took {elapsed:.3f} seconds**")
    if results:
        st.header(f"Top {top_k} Results")
        # ----------------------------------------------------------
        # Display results (card-style)
        # ----------------------------------------------------------
        latex = LatexNodes2Text()
        for i, paper in enumerate(results, start=1):
            title = latex.latex_to_text(paper['title'].replace('\n', ' ').strip())
            abstract = latex.latex_to_text(paper['abstract'][:600])
            st.markdown(f"### **[{i}. {title}](https://arxiv.org/abs/{paper['id']})**")
            st.markdown(f"**Categories:** {paper['categories']}")
            st.markdown(f"**Abstract:** {abstract}...")
            st.markdown("---")
    else:
        st.markdown("No results: none of the query words are in the model's vocabulary.")