Spaces:
Runtime error
Runtime error
bug fixes
Browse files
app.py
CHANGED
|
@@ -50,7 +50,7 @@ def predict(text: str = sample_text, top_k: int=3):
|
|
| 50 |
# query = prepare_query(tokenizer, text)
|
| 51 |
index_data, faiss_index = index
|
| 52 |
# takes only the [CLS] embedding (for now)
|
| 53 |
-
query = model(text)[0][0].numpy().reshape(1, -1)
|
| 54 |
|
| 55 |
scores, indices = faiss_index.search(query, top_k)
|
| 56 |
scores, indices = scores.tolist(), indices.tolist()
|
|
|
|
| 50 |
# query = prepare_query(tokenizer, text)
|
| 51 |
index_data, faiss_index = index
|
| 52 |
# takes only the [CLS] embedding (for now)
|
| 53 |
+
query = model(text, return_tensors='pt')[0][0].numpy().reshape(1, -1)
|
| 54 |
|
| 55 |
scores, indices = faiss_index.search(query, top_k)
|
| 56 |
scores, indices = scores.tolist(), indices.tolist()
|