Spaces:
Sleeping
Sleeping
| from sentence_transformers import SentenceTransformer | |
| from sentence_transformers import util | |
| import torch | |
def predict(
    query: str,
    corpus_embeddings: torch.Tensor,
    corpus_labels: list,
    model: SentenceTransformer,
    top_k: int = 5,
) -> list:
    """Return the labels of the corpus entries that best match ``query``.

    The query is embedded with ``model`` and searched against
    ``corpus_embeddings`` through ``util.semantic_search``. Each hit's
    ``corpus_id`` is then used as an index into ``corpus_labels``.

    Args:
        query: Text to look up.
        corpus_embeddings: Precomputed corpus embeddings. Presumably entry
            ``i`` corresponds to ``corpus_labels[i]``, since labels are
            indexed by the returned ``corpus_id`` -- confirm with the caller.
        corpus_labels: Labels, indexed by ``corpus_id``.
        model: Encoder used to embed the query.
        top_k: Number of hits requested from the semantic search.

    Returns:
        The labels of the hits for the query, in the order in which
        ``util.semantic_search`` returned them.
    """
    # A one-element batch is encoded, so the search produces exactly one
    # list of hits; that list is the first (and only) entry of the result.
    encoded_query = model.encode([query])
    hits_per_query = util.semantic_search(
        encoded_query,
        corpus_embeddings,
        top_k=top_k,
    )

    labels: list = []
    for hit in hits_per_query[0]:
        labels.append(corpus_labels[hit["corpus_id"]])
    return labels