Jonas Leeb
committed on
Commit
·
5355a96
1
Parent(s):
dfc89a9
small updates
Browse files
app.py
CHANGED
|
@@ -149,7 +149,7 @@ class ArxivSearch:
|
|
| 149 |
y=reduced_data[:, 1],
|
| 150 |
z=reduced_data[:, 2],
|
| 151 |
mode='markers',
|
| 152 |
-
marker=dict(size=3.5, color=
|
| 153 |
name='All Documents'
|
| 154 |
)
|
| 155 |
layout = go.Layout(
|
|
@@ -223,6 +223,21 @@ class ArxivSearch:
|
|
| 223 |
top_indices = sims.argsort()[::-1][:top_n]
|
| 224 |
return [(i, sims[i]) for i in top_indices]
|
| 225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
def load_model(self, embedding):
|
| 227 |
if embedding == "tfidf":
|
| 228 |
self.tfidf_matrix = load_npz("TF-IDF embeddings/tfidf_matrix_train.npz")
|
|
|
|
| 149 |
y=reduced_data[:, 1],
|
| 150 |
z=reduced_data[:, 2],
|
| 151 |
mode='markers',
|
| 152 |
+
marker=dict(size=3.5, color="#ffffff", opacity=0.2),
|
| 153 |
name='All Documents'
|
| 154 |
)
|
| 155 |
layout = go.Layout(
|
|
|
|
| 223 |
top_indices = sims.argsort()[::-1][:top_n]
|
| 224 |
return [(i, sims[i]) for i in top_indices]
|
| 225 |
|
| 226 |
+
def bert_search_2(self, query, top_n=5):
    """Rank the stored BERT embeddings against a mean-pooled query embedding.

    The query is tokenized (with truncation and padding), passed through
    ``self.model``, and the final hidden states are averaged over the
    tokens selected by the attention mask. The pooled vector is stored on
    ``self.query_encoding`` and compared to ``self.bert_embeddings`` with
    cosine similarity.

    Args:
        query: Text to search for.
        top_n: Maximum number of results to return.

    Returns:
        A list of ``(index, similarity)`` pairs, ordered from most to
        least similar.
    """
    with torch.no_grad():
        encoded = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
        hidden_states = self.model(**encoded).last_hidden_state

        # Broadcast the attention mask across the hidden dimension so that
        # padded positions contribute nothing to the pooled sum.
        token_mask = encoded['attention_mask'].unsqueeze(-1)
        token_mask = token_mask.expand(hidden_states.size()).float()

        pooled_sum = (hidden_states * token_mask).sum(dim=1)
        # Clamp the token count so the division below can never hit zero.
        token_count = token_mask.sum(1).clamp(min=1e-9)
        query_vec = pooled_sum / token_count

    # Keep the pooled query vector on the instance after the search.
    self.query_encoding = query_vec

    scores = cosine_similarity(query_vec, self.bert_embeddings).flatten()
    best_first = scores.argsort()[::-1][:top_n]
    return [(doc_idx, scores[doc_idx]) for doc_idx in best_first]
|
| 240 |
+
|
| 241 |
def load_model(self, embedding):
|
| 242 |
if embedding == "tfidf":
|
| 243 |
self.tfidf_matrix = load_npz("TF-IDF embeddings/tfidf_matrix_train.npz")
|