Spaces:
Runtime error
Runtime error
| ### Imports | |
| from sentence_transformers import SentenceTransformer, util | |
| ### Classes and functions | |
| ##========================================================================================================== | |
class SentTransfUtilities:
    """
    Convenience wrapper around a sentence-transformers model.

    Loaded models are cached at class level, keyed by model name, so creating
    several instances with the same model name loads the weights only once.
    """

    # Model used by this instance; assigned in __init__. Kept as a class
    # attribute (default None) for backward compatibility.
    model = None
    __model_name = None
    # Intentionally shared, class-wide cache: model name -> loaded model.
    # The original code tried to reuse the model via the `model` class
    # attribute, but assigning `self.model` only created an instance
    # attribute, so every new instance reloaded the weights.
    _model_cache = {}

    ##======================================================================================================
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """
        Load (or reuse a cached) Sentence Transformer model.

        Arguments:
            - model_name: name of the model to load. Example options:
                - 'all-MiniLM-L6-v2'
                - 'nq-distilbert-base-v1'
                - 'paraphrase-multilingual-MiniLM-L12-v2'
        """
        self.__model_name = model_name
        cached_model = self._model_cache.get(model_name)
        if cached_model is None:
            print("Initializing the Sentence Transformer model")
            cached_model = SentenceTransformer(model_name)
            # Mutates the shared dict, so later instances see this model.
            self._model_cache[model_name] = cached_model
        self.model = cached_model

    ##======================================================================================================
    def get_embeddings(self, src_data, device="cpu"):
        """
        Encode `src_data` with the model and return the result as a tensor.

        Arguments:
            - src_data: a sentence or a list of sentences to encode.
            - device: device passed to `encode` (defaults to 'cpu', as before).
        """
        return self.model.encode(src_data, convert_to_tensor=True, device=device)

    ##======================================================================================================
    def compute_cosine_similarity(self, query_embeddings, passage_embeddings):
        """
        Return the cosine-similarity scores between two sets of embeddings.

        Arguments:
            - query_embeddings
            - passage_embeddings
        """
        return util.cos_sim(query_embeddings, passage_embeddings)

    ##======================================================================================================
    def compute_dot_similarity(self, query_embeddings, passage_embeddings):
        """
        Return the dot-product scores between two sets of embeddings.

        Arguments:
            - query_embeddings
            - passage_embeddings
        """
        return util.dot_score(query_embeddings, passage_embeddings)

    ##======================================================================================================
    def compute_semantic_search(self, query_embeddings, corpus_embeddings, top_k=10):
        """
        Run sentence-transformers semantic search of the queries against the corpus.

        Arguments:
            - query_embeddings
            - corpus_embeddings
            - top_k: number of hits to retrieve per query (10 matches the
              library default, so existing callers see no change).

        Returns whatever `util.semantic_search` returns. Per the
        sentence-transformers docs, that is one list of hit dicts per query.
        """
        return util.semantic_search(query_embeddings, corpus_embeddings, top_k=top_k)

    ##======================================================================================================
    def compute_sentences_similarity(self, sentence_1, sentence_2, sim_func="cosine"):
        """
        Embed two sentences (or lists of sentences) and score them.

        Arguments:
            - sentence_1
            - sentence_2
            - sim_func: { "cosine", "dot" }

        Raises:
            - ValueError: if `sim_func` is not one of the supported names.
              Previously this case silently returned None.
        """
        similarity_funcs = {
            "cosine": self.compute_cosine_similarity,
            "dot": self.compute_dot_similarity,
        }
        # Validate before encoding so a bad argument doesn't waste model work.
        try:
            similarity = similarity_funcs[sim_func]
        except KeyError:
            raise ValueError(
                f"Unsupported sim_func {sim_func!r}; expected one of {sorted(similarity_funcs)}"
            ) from None
        embeddings_1 = self.get_embeddings(sentence_1)
        embeddings_2 = self.get_embeddings(sentence_2)
        return similarity(embeddings_1, embeddings_2)

    ##======================================================================================================
| ##========================================================================================================== | |