nullHawk committed
Commit 2d50028 · verified · 1 Parent(s): 66c9852

fix: db optimizations

Files changed (1): app.py (+67 -32)
app.py CHANGED
@@ -1,10 +1,11 @@
 from huggingface_hub import hf_hub_download
 from gensim.models import Word2Vec
-from nltk import word_tokenize
+from nltk import word_tokenize, sent_tokenize
 from pylatexenc.latex2text import LatexNodes2Text
 
 import faiss
 import duckdb
+import time
 
 import streamlit as st
 import numpy as np
@@ -16,17 +17,35 @@ def get_db(path='arxiv.db'):
     return duckdb.connect(path)
 
 
-def query_neighbours(rows: list):
-    global db
-    con = db
-    rows = [int(x) for x in rows]  # Convert numpy.int64 → Python int
-    placeholders = ",".join("?" for _ in rows)
-    df = con.execute(
-        f"SELECT * FROM arxiv WHERE column0 IN ({placeholders})",
-        rows,
-    ).fetchdf()
-
-    return df.to_dict("records")
+@st.cache_resource
+def get_fast_lookup(_model):
+    vectors = _model.wv.vectors  # NumPy matrix (fast)
+    word_to_index = {word: idx for idx, word in enumerate(_model.wv.index_to_key)}
+    return vectors, word_to_index
+
+@st.cache_resource
+def load_arxiv_dict():
+    con = duckdb.connect("arxiv.db")
+    df = con.execute("""
+        SELECT column0, id, title, abstract, categories
+        FROM arxiv
+    """).fetchdf()
+
+    # dictionary: column0 → row
+    return {
+        int(row["column0"]): {
+            "id": row["id"],
+            "title": row["title"],
+            "abstract": row["abstract"],
+            "categories": row["categories"]
+        }
+        for _, row in df.iterrows()
+    }
+
+def query_neighbours(rows):
+    global arxiv_dict
+    return [arxiv_dict.get(int(x)) for x in rows if int(x) in arxiv_dict]
+
 
 @st.cache_resource
 def get_model():
@@ -52,24 +71,34 @@ def get_faiss_index():
 
 
 def run_semantic_search(query, top_k):
-    global model
-    global faiss_index
+    global model, faiss_index, word_to_index, vectors
 
    index = faiss_index
 
-    words = word_tokenize(query.lower())
+    words = query.lower().split()
    vecs = []
 
+    start_t = time.time()
    for w in words:
-        if w in model.wv:
-            vecs.append(model.wv[w])
-    if len(vecs) == 0:
+        idx = word_to_index.get(w)
+        if idx is not None:
+            vecs.append(vectors[idx])
+    mid_t = time.time()
+    print(f"Tokenization time: {mid_t - start_t}")
+
+    if not vecs:
        return []
+
    qvec = np.mean(vecs, axis=0).astype('float32').reshape(1, -1)
    faiss.normalize_L2(qvec)
+
    scores, neighbors = index.search(qvec, top_k)
+    mid2_t = time.time()
+    print(f"Search time : {mid2_t - mid_t}")
+    result = query_neighbours(neighbors[0])
+    print(f"Query time : {time.time() - mid2_t}\n\n\n")
+    return result
 
-    return query_neighbours(neighbors[0])
 
 
 #-----------------------------------
@@ -79,6 +108,8 @@ def run_semantic_search(query, top_k):
 model = get_model()
 faiss_index = get_faiss_index()
 db = get_db()
+vectors, word_to_index = get_fast_lookup(model)
+arxiv_dict = load_arxiv_dict()
 
 # ----------------------------------
 # Streamlit Page Setup
@@ -105,20 +136,24 @@ search_button = st.button("Search")
 # Handle search click
 # --------------------------------------------------------------
 if search_button and query.strip():
+    start_time = time.time()
     with st.spinner("Searching..."):
         results = run_semantic_search(query, top_k)
-
-    st.header(f"Top {top_k} Results")
-
-    # ----------------------------------------------------------
-    # Display results (card-style)
-    # ----------------------------------------------------------
-    for i, paper in enumerate(results, start=1):
-        st.markdown(f"### **{i}. {LatexNodes2Text().latex_to_text(paper['title'].replace("\n", " ").strip())}**")
-
-        st.markdown(f"**Categories:** {paper['categories']}")
-        st.markdown(f"**Abstract:** {LatexNodes2Text().latex_to_text(paper["abstract"][:600])}...")
-        st.markdown(f"[View on arXiv](https://arxiv.org/abs/{paper['id']})")
-
-        st.markdown("---")
+    end_time = time.time()
+    elapsed = end_time - start_time
+    st.write(f"**Your query took {elapsed:.3f} seconds**")
+    if(len(results) != 0):
+        st.header(f"Top {top_k} Results")
+
+        # ----------------------------------------------------------
+        # Display results (card-style)
+        # ----------------------------------------------------------
+        for i, paper in enumerate(results, start=1):
+            st.markdown(f"### **[{i}. {LatexNodes2Text().latex_to_text(paper['title'].replace("\n", " ").strip())}](https://arxiv.org/abs/{paper['id']})**")
+
+            st.markdown(f"**Categories:** {paper['categories']}")
+            st.markdown(f"**Abstract:** {LatexNodes2Text().latex_to_text(paper["abstract"][:600])}...")
+            st.markdown("---")
+    else:
+        st.markdown(f"No Results, either model is not trained on this word")
 
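For readers skimming the diff: both changes apply the same precompute-once pattern. The old query_neighbours ran a parameterized IN (...) query against DuckDB on every search; the new load_arxiv_dict scans the table once at startup and serves each search from a plain dictionary, and get_fast_lookup does the analogous thing for the word vectors. Below is a minimal, self-contained sketch of the database half of that pattern; an in-memory table with toy rows stands in for arxiv.db, and the function names are illustrative rather than verbatim app code:

import duckdb

# Toy stand-in for arxiv.db: same key column, two fake rows.
con = duckdb.connect(":memory:")
con.execute("CREATE TABLE arxiv (column0 INTEGER, id VARCHAR, title VARCHAR)")
con.execute("INSERT INTO arxiv VALUES (0, '1234.5678', 'Paper A'), (1, '2345.6789', 'Paper B')")

def query_neighbours_sql(rows):
    # old approach: one database round trip per search
    rows = [int(x) for x in rows]  # numpy.int64 -> Python int
    placeholders = ",".join("?" for _ in rows)
    df = con.execute(
        f"SELECT * FROM arxiv WHERE column0 IN ({placeholders})", rows
    ).fetchdf()
    return df.to_dict("records")

# new approach: one full scan up front, then every search is a dict lookup
arxiv_dict = {
    int(row["column0"]): {"id": row["id"], "title": row["title"]}
    for _, row in con.execute("SELECT * FROM arxiv").fetchdf().iterrows()
}

def query_neighbours_dict(rows):
    return [arxiv_dict[int(x)] for x in rows if int(x) in arxiv_dict]

print(query_neighbours_sql([0, 1]))   # hits DuckDB on every call
print(query_neighbours_dict([0, 1]))  # no database access at query time

The trade-off is memory: the whole table now lives in the Streamlit process, cached across reruns via @st.cache_resource, which is what buys the per-query speedup that the new print timers measure.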