Jonas Leeb
commited on
Commit
·
6c71bbc
1
Parent(s):
65f9879
added plot
Browse files
app.py
CHANGED
|
@@ -8,6 +8,9 @@ from transformers import BertTokenizer, BertModel
|
|
| 8 |
import numpy as np
|
| 9 |
from datasets import load_dataset
|
| 10 |
from gensim.models import KeyedVectors
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
|
|
@@ -19,6 +22,7 @@ class ArxivSearch:
|
|
| 19 |
self.titles = []
|
| 20 |
self.raw_texts = []
|
| 21 |
self.arxiv_ids = []
|
|
|
|
| 22 |
|
| 23 |
self.embedding_dropdown = gr.Dropdown(
|
| 24 |
choices=["tfidf", "word2vec", "bert"],
|
|
@@ -26,16 +30,48 @@ class ArxivSearch:
|
|
| 26 |
label="Model"
|
| 27 |
)
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
self.load_data(dataset)
|
| 41 |
# self.load_model(embedding)
|
|
@@ -45,15 +81,6 @@ class ArxivSearch:
|
|
| 45 |
|
| 46 |
self.iface.launch()
|
| 47 |
|
| 48 |
-
|
| 49 |
-
# # --- Load data and embeddings ---
|
| 50 |
-
# with open("feature_names.txt", "r") as f:
|
| 51 |
-
# feature_names = [line.strip() for line in f]
|
| 52 |
-
|
| 53 |
-
# tfidf_matrix = load_npz("tfidf_matrix_train.npz")
|
| 54 |
-
|
| 55 |
-
# Load dataset and initialize search engine
|
| 56 |
-
|
| 57 |
def load_data(self, dataset):
|
| 58 |
train_data = dataset["train"]
|
| 59 |
for item in train_data.select(range(len(train_data))):
|
|
@@ -99,6 +126,57 @@ class ArxivSearch:
|
|
| 99 |
scores.append((doc_idx, doc_score))
|
| 100 |
scores.sort(key=lambda x: x[1], reverse=True)
|
| 101 |
return scores[:top_n]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def word2vec_search(self, query, top_n=5):
|
| 104 |
tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
|
|
@@ -163,7 +241,12 @@ class ArxivSearch:
|
|
| 163 |
return "No results found."
|
| 164 |
|
| 165 |
if not results:
|
|
|
|
| 166 |
return "No results found."
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
output = ""
|
| 169 |
display_rank = 1
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
from datasets import load_dataset
|
| 10 |
from gensim.models import KeyedVectors
|
| 11 |
+
import plotly.graph_objects as go
|
| 12 |
+
from sklearn.decomposition import PCA
|
| 13 |
+
|
| 14 |
|
| 15 |
|
| 16 |
|
|
|
|
| 22 |
self.titles = []
|
| 23 |
self.raw_texts = []
|
| 24 |
self.arxiv_ids = []
|
| 25 |
+
self.last_results = []
|
| 26 |
|
| 27 |
self.embedding_dropdown = gr.Dropdown(
|
| 28 |
choices=["tfidf", "word2vec", "bert"],
|
|
|
|
| 30 |
label="Model"
|
| 31 |
)
|
| 32 |
|
| 33 |
+
|
| 34 |
+
# Add a button to show the 3D plot
|
| 35 |
+
self.plot_button = gr.Button("Show 3D Plot")
|
| 36 |
+
|
| 37 |
+
# Define the interface using Blocks for more flexibility
|
| 38 |
+
with gr.Blocks() as self.iface:
|
| 39 |
+
gr.Markdown("# arXiv Search Engine")
|
| 40 |
+
gr.Markdown("Search arXiv papers by keyword and embedding model.")
|
| 41 |
+
with gr.Row():
|
| 42 |
+
self.query_box = gr.Textbox(lines=1, placeholder="Enter your search query", label="Query")
|
| 43 |
+
self.embedding_dropdown.render()
|
| 44 |
+
self.plot_button.render()
|
| 45 |
+
with gr.Row():
|
| 46 |
+
self.plot_output = gr.Plot()
|
| 47 |
+
self.output_md = gr.Markdown()
|
| 48 |
+
|
| 49 |
+
self.query_box.submit(
|
| 50 |
+
self.search_function,
|
| 51 |
+
inputs=[self.query_box, self.embedding_dropdown],
|
| 52 |
+
outputs=self.output_md
|
| 53 |
+
)
|
| 54 |
+
self.embedding_dropdown.change(
|
| 55 |
+
self.search_function,
|
| 56 |
+
inputs=[self.query_box, self.embedding_dropdown],
|
| 57 |
+
outputs=self.output_md
|
| 58 |
+
)
|
| 59 |
+
self.plot_button.click(
|
| 60 |
+
self.plot_3d_embeddings,
|
| 61 |
+
inputs=[self.embedding_dropdown],
|
| 62 |
+
outputs=self.plot_output
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# self.iface = gr.Interface(
|
| 66 |
+
# fn=self.search_function,
|
| 67 |
+
# inputs=[
|
| 68 |
+
# gr.Textbox(lines=1, placeholder="Enter your search query"),
|
| 69 |
+
# self.embedding_dropdown
|
| 70 |
+
# ],
|
| 71 |
+
# outputs=gr.Markdown(),
|
| 72 |
+
# title="arXiv Search Engine",
|
| 73 |
+
# description="Search arXiv papers by keyword and embedding model.",
|
| 74 |
+
# )
|
| 75 |
|
| 76 |
self.load_data(dataset)
|
| 77 |
# self.load_model(embedding)
|
|
|
|
| 81 |
|
| 82 |
self.iface.launch()
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
def load_data(self, dataset):
|
| 85 |
train_data = dataset["train"]
|
| 86 |
for item in train_data.select(range(len(train_data))):
|
|
|
|
| 126 |
scores.append((doc_idx, doc_score))
|
| 127 |
scores.sort(key=lambda x: x[1], reverse=True)
|
| 128 |
return scores[:top_n]
|
| 129 |
+
|
| 130 |
+
def plot_3d_embeddings(self, embedding):
|
| 131 |
+
# Example: plot random points, replace with your embeddings
|
| 132 |
+
pca = PCA(n_components=3)
|
| 133 |
+
results_indices = [i[0] for i in self.last_results]
|
| 134 |
+
if embedding == "tfidf":
|
| 135 |
+
reduced_data = pca.fit_transform(self.tfidf_matrix[:5000].toarray())
|
| 136 |
+
reduced_results_points = pca.transform(self.tfidf_matrix[results_indices].toarray()) if len(results_indices) > 0 else np.empty((0, 3))
|
| 137 |
+
|
| 138 |
+
elif embedding == "word2vec":
|
| 139 |
+
reduced_data = pca.fit_transform(self.word2vec_embeddings[:5000])
|
| 140 |
+
reduced_results_points = pca.transform(self.word2vec_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
|
| 141 |
+
|
| 142 |
+
elif embedding == "bert":
|
| 143 |
+
reduced_data = pca.fit_transform(self.bert_embeddings[:5000])
|
| 144 |
+
reduced_results_points = pca.transform(self.bert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
|
| 145 |
+
else:
|
| 146 |
+
raise ValueError(f"Unsupported embedding type: {embedding}")
|
| 147 |
+
trace = go.Scatter3d(
|
| 148 |
+
x=reduced_data[:, 0],
|
| 149 |
+
y=reduced_data[:, 1],
|
| 150 |
+
z=reduced_data[:, 2],
|
| 151 |
+
mode='markers',
|
| 152 |
+
marker=dict(size=3.5, color='white', opacity=0.4),
|
| 153 |
+
)
|
| 154 |
+
layout = go.Layout(
|
| 155 |
+
margin=dict(l=0, r=0, b=0, t=0),
|
| 156 |
+
scene=dict(
|
| 157 |
+
xaxis_title='X',
|
| 158 |
+
yaxis_title='Y',
|
| 159 |
+
zaxis_title='Z',
|
| 160 |
+
xaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
|
| 161 |
+
yaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
|
| 162 |
+
zaxis=dict(backgroundcolor='black', color='white', gridcolor='gray', zerolinecolor='gray'),
|
| 163 |
+
),
|
| 164 |
+
paper_bgcolor='black', # Outside the plotting area
|
| 165 |
+
plot_bgcolor='black', # Plotting area
|
| 166 |
+
font=dict(color='white') # Axis and legend text
|
| 167 |
+
)
|
| 168 |
+
if len(reduced_results_points) > 0:
|
| 169 |
+
results_trace = go.Scatter3d(
|
| 170 |
+
x=reduced_results_points[:, 0],
|
| 171 |
+
y=reduced_results_points[:, 1],
|
| 172 |
+
z=reduced_results_points[:, 2],
|
| 173 |
+
mode='markers',
|
| 174 |
+
marker=dict(size=3.5, color='orange', opacity=0.9),
|
| 175 |
+
)
|
| 176 |
+
fig = go.Figure(data=[trace, results_trace], layout=layout)
|
| 177 |
+
else:
|
| 178 |
+
fig = go.Figure(data=[trace], layout=layout)
|
| 179 |
+
return fig
|
| 180 |
|
| 181 |
def word2vec_search(self, query, top_n=5):
|
| 182 |
tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
|
|
|
|
| 241 |
return "No results found."
|
| 242 |
|
| 243 |
if not results:
|
| 244 |
+
self.last_results = []
|
| 245 |
return "No results found."
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
if results:
|
| 249 |
+
self.last_results = results
|
| 250 |
|
| 251 |
output = ""
|
| 252 |
display_rank = 1
|