Jonas Leeb committed on
Commit ·
dc760b4
1
Parent(s): 0fbc2c7
bug fixes and usability improvements
Browse files
app.py
CHANGED
|
@@ -57,8 +57,8 @@ class ArxivSearch:
|
|
| 57 |
outputs=self.output_md
|
| 58 |
)
|
| 59 |
self.embedding_dropdown.change(
|
| 60 |
-
self.
|
| 61 |
-
inputs=[self.
|
| 62 |
outputs=self.output_md
|
| 63 |
)
|
| 64 |
self.plot_button.click(
|
|
@@ -73,11 +73,12 @@ class ArxivSearch:
|
|
| 73 |
)
|
| 74 |
|
| 75 |
self.load_data(dataset)
|
| 76 |
-
self.load_model(
|
| 77 |
-
self.load_model('
|
| 78 |
-
self.load_model('
|
| 79 |
-
self.load_model('
|
| 80 |
-
self.load_model('
|
|
|
|
| 81 |
|
| 82 |
self.iface.launch()
|
| 83 |
|
|
@@ -114,7 +115,6 @@ class ArxivSearch:
|
|
| 114 |
self.arxiv_ids.append(arxiv_id)
|
| 115 |
|
| 116 |
def plot_dense(self, embedding, pca, results_indices):
|
| 117 |
-
print(self.query_encoding.shape[0])
|
| 118 |
all_indices = list(set(results_indices) | set(range(min(5000, embedding.shape[0]))))
|
| 119 |
all_data = embedding[all_indices]
|
| 120 |
pca.fit(all_data)
|
|
@@ -149,7 +149,9 @@ class ArxivSearch:
|
|
| 149 |
z=reduced_data[:, 2],
|
| 150 |
mode='markers',
|
| 151 |
marker=dict(size=3.5, color="#ffffff", opacity=0.2),
|
| 152 |
-
name='All Documents'
|
|
|
|
|
|
|
| 153 |
)
|
| 154 |
layout = go.Layout(
|
| 155 |
margin=dict(l=0, r=0, b=0, t=0),
|
|
@@ -172,7 +174,9 @@ class ArxivSearch:
|
|
| 172 |
z=reduced_results_points[:, 2],
|
| 173 |
mode='markers',
|
| 174 |
marker=dict(size=3.5, color='orange', opacity=0.75),
|
| 175 |
-
name='Results'
|
|
|
|
|
|
|
| 176 |
)
|
| 177 |
if not self.embedding == "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
|
| 178 |
query_trace = go.Scatter3d(
|
|
@@ -181,7 +185,9 @@ class ArxivSearch:
|
|
| 181 |
z=query_point[:, 2],
|
| 182 |
mode='markers',
|
| 183 |
marker=dict(size=5, color='red', opacity=0.8),
|
| 184 |
-
name='Query'
|
|
|
|
|
|
|
| 185 |
)
|
| 186 |
fig = go.Figure(data=[trace, results_trace, query_trace], layout=layout)
|
| 187 |
else:
|
|
@@ -209,7 +215,7 @@ class ArxivSearch:
|
|
| 209 |
if not tokens:
|
| 210 |
return []
|
| 211 |
vectors = np.array([self.wv_model[word] for word in tokens])
|
| 212 |
-
query_vec =
|
| 213 |
self.query_encoding = query_vec
|
| 214 |
sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
|
| 215 |
top_indices = sims.argsort()[::-1][:top_n]
|
|
@@ -219,7 +225,6 @@ class ArxivSearch:
|
|
| 219 |
with torch.no_grad():
|
| 220 |
inputs = self.tokenizer((query+' ')*2, return_tensors="pt", truncation=True, max_length=512, padding='max_length')
|
| 221 |
outputs = self.model(**inputs)
|
| 222 |
-
# query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
|
| 223 |
query_vec = outputs.last_hidden_state[:, 0, :].numpy()
|
| 224 |
|
| 225 |
self.query_encoding = query_vec
|
|
@@ -251,6 +256,38 @@ class ArxivSearch:
|
|
| 251 |
top_indices = top_k_indices[final_scores.argsort()[::-1][:top_n]]
|
| 252 |
print(f"sim, top_indices: {final_scores}, {top_indices}")
|
| 253 |
return [(top_k_indices[i], final_scores[i]) for i in final_scores.argsort()[::-1][:top_n]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
def load_model(self, embedding):
|
| 256 |
self.embedding = embedding
|
|
@@ -291,8 +328,9 @@ class ArxivSearch:
|
|
| 291 |
def set_embedding(self, embedding):
|
| 292 |
self.embedding = embedding
|
| 293 |
|
| 294 |
-
def search_function(self, query, embedding):
|
| 295 |
self.set_embedding(embedding)
|
|
|
|
| 296 |
query = query.encode().decode('unicode_escape') # Interpret escape sequences
|
| 297 |
|
| 298 |
# Load or switch embedding model here if needed
|
|
|
|
| 57 |
outputs=self.output_md
|
| 58 |
)
|
| 59 |
self.embedding_dropdown.change(
|
| 60 |
+
self.model_switch,
|
| 61 |
+
inputs=[self.embedding_dropdown],
|
| 62 |
outputs=self.output_md
|
| 63 |
)
|
| 64 |
self.plot_button.click(
|
|
|
|
| 73 |
)
|
| 74 |
|
| 75 |
self.load_data(dataset)
|
| 76 |
+
self.load_model(embedding)
|
| 77 |
+
# self.load_model('tfidf')
|
| 78 |
+
# self.load_model('word2vec')
|
| 79 |
+
# self.load_model('bert')
|
| 80 |
+
# self.load_model('scibert')
|
| 81 |
+
# self.load_model('sbert')
|
| 82 |
|
| 83 |
self.iface.launch()
|
| 84 |
|
|
|
|
| 115 |
self.arxiv_ids.append(arxiv_id)
|
| 116 |
|
| 117 |
def plot_dense(self, embedding, pca, results_indices):
|
|
|
|
| 118 |
all_indices = list(set(results_indices) | set(range(min(5000, embedding.shape[0]))))
|
| 119 |
all_data = embedding[all_indices]
|
| 120 |
pca.fit(all_data)
|
|
|
|
| 149 |
z=reduced_data[:, 2],
|
| 150 |
mode='markers',
|
| 151 |
marker=dict(size=3.5, color="#ffffff", opacity=0.2),
|
| 152 |
+
name='All Documents',
|
| 153 |
+
text=[f"<br>: {self.arxiv_ids[i] if self.arxiv_ids[i] else self.documents[i].split()[:10]}" for i in range(len(self.documents))],
|
| 154 |
+
hoverinfo='text'
|
| 155 |
)
|
| 156 |
layout = go.Layout(
|
| 157 |
margin=dict(l=0, r=0, b=0, t=0),
|
|
|
|
| 174 |
z=reduced_results_points[:, 2],
|
| 175 |
mode='markers',
|
| 176 |
marker=dict(size=3.5, color='orange', opacity=0.75),
|
| 177 |
+
name='Results',
|
| 178 |
+
text=[f"<br>Snippet: {self.documents[i][:200]}" for i in results_indices],
|
| 179 |
+
hoverinfo='text'
|
| 180 |
)
|
| 181 |
if not self.embedding == "tfidf" and self.query_encoding is not None and self.query_encoding.shape[0] > 0:
|
| 182 |
query_trace = go.Scatter3d(
|
|
|
|
| 185 |
z=query_point[:, 2],
|
| 186 |
mode='markers',
|
| 187 |
marker=dict(size=5, color='red', opacity=0.8),
|
| 188 |
+
name='Query',
|
| 189 |
+
text=[f"<br>Query: {self.query}"],
|
| 190 |
+
hoverinfo='text'
|
| 191 |
)
|
| 192 |
fig = go.Figure(data=[trace, results_trace, query_trace], layout=layout)
|
| 193 |
else:
|
|
|
|
| 215 |
if not tokens:
|
| 216 |
return []
|
| 217 |
vectors = np.array([self.wv_model[word] for word in tokens])
|
| 218 |
+
query_vec = np.mean(vectors, axis=0).reshape(1, -1)
|
| 219 |
self.query_encoding = query_vec
|
| 220 |
sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
|
| 221 |
top_indices = sims.argsort()[::-1][:top_n]
|
|
|
|
| 225 |
with torch.no_grad():
|
| 226 |
inputs = self.tokenizer((query+' ')*2, return_tensors="pt", truncation=True, max_length=512, padding='max_length')
|
| 227 |
outputs = self.model(**inputs)
|
|
|
|
| 228 |
query_vec = outputs.last_hidden_state[:, 0, :].numpy()
|
| 229 |
|
| 230 |
self.query_encoding = query_vec
|
|
|
|
| 256 |
top_indices = top_k_indices[final_scores.argsort()[::-1][:top_n]]
|
| 257 |
print(f"sim, top_indices: {final_scores}, {top_indices}")
|
| 258 |
return [(top_k_indices[i], final_scores[i]) for i in final_scores.argsort()[::-1][:top_n]]
|
| 259 |
+
|
| 260 |
+
def model_switch(self, embedding, progress=gr.Progress()):
|
| 261 |
+
if self.embedding != embedding:
|
| 262 |
+
old_embedding = self.embedding
|
| 263 |
+
print(f"Switching model to {embedding}")
|
| 264 |
+
self.load_model(embedding)
|
| 265 |
+
print(f"Loaded {embedding} model")
|
| 266 |
+
self.embedding = embedding
|
| 267 |
+
if old_embedding == "tfidf":
|
| 268 |
+
del self.tfidf_matrix
|
| 269 |
+
del self.feature_names
|
| 270 |
+
if old_embedding == "word2vec":
|
| 271 |
+
del self.word2vec_embeddings
|
| 272 |
+
del self.wv_model
|
| 273 |
+
if old_embedding == "bert":
|
| 274 |
+
del self.bert_embeddings
|
| 275 |
+
del self.tokenizer
|
| 276 |
+
del self.model
|
| 277 |
+
if old_embedding == "scibert":
|
| 278 |
+
del self.scibert_embeddings
|
| 279 |
+
del self.sci_tokenizer
|
| 280 |
+
del self.sci_model
|
| 281 |
+
if old_embedding == "sbert":
|
| 282 |
+
del self.sbert_model
|
| 283 |
+
del self.sbert_embedding
|
| 284 |
+
del self.cross_encoder
|
| 285 |
+
print(f"old embedding removed")
|
| 286 |
+
if hasattr(self, "query") and self.query:
|
| 287 |
+
return self.search_function(self.query, self.embedding)
|
| 288 |
+
else:
|
| 289 |
+
return "" # Or a message like "Model switched. Please enter a query."
|
| 290 |
+
return gr.update() # No change if embedding is the same
|
| 291 |
|
| 292 |
def load_model(self, embedding):
|
| 293 |
self.embedding = embedding
|
|
|
|
| 328 |
def set_embedding(self, embedding):
|
| 329 |
self.embedding = embedding
|
| 330 |
|
| 331 |
+
def search_function(self, query, embedding, progress=gr.Progress()):
|
| 332 |
self.set_embedding(embedding)
|
| 333 |
+
self.query = query
|
| 334 |
query = query.encode().decode('unicode_escape') # Interpret escape sequences
|
| 335 |
|
| 336 |
# Load or switch embedding model here if needed
|