Spaces:
Runtime error
Runtime error
| import gzip | |
| import json | |
| import pandas as pd | |
| import numpy as np | |
| import jax.numpy as jnp | |
| import tqdm | |
| from sentence_transformers import util | |
| from typing import List, Union | |
| import torch | |
| from backend.utils import load_model, filter_questions, load_embeddings | |
def cos_sim(a, b, eps: float = 1e-12):
    """Return the cosine similarity between the rows of ``a`` and the rows of ``b``.

    For ``a`` of shape (m, d) and ``b`` of shape (n, d) the result has shape
    (m, n), with entry [i, j] equal to cos(a[i], b[j]). One-dimensional inputs
    are treated as single vectors, and the result is then a scalar.

    Args:
        a: Array-like of shape (d,) or (m, d).
        b: Array-like of shape (d,) or (n, d).
        eps: Lower bound applied to each row norm, so that an all-zero row
            gives a similarity of 0 instead of NaN.

    Returns:
        A jax array of pairwise cosine similarities.
    """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    # Normalise each row on its own (axis=-1). A norm taken over the whole
    # matrix would only be correct when each argument holds a single vector.
    a_unit = a / jnp.maximum(jnp.linalg.norm(a, axis=-1, keepdims=True), eps)
    b_unit = b / jnp.maximum(jnp.linalg.norm(b, axis=-1, keepdims=True), eps)
    return jnp.matmul(a_unit, jnp.transpose(b_unit))
# We get similarity between embeddings.
def text_similarity(anchor: str, inputs: List[str], model_name: str, model_dict: dict):
    """Score each string in ``inputs`` by its similarity to ``anchor``.

    Args:
        anchor: Reference text.
        inputs: Texts to compare against ``anchor``.
        model_name: Key passed to ``load_model``.
        model_dict: Mapping passed to ``load_model``.

    ``load_model`` may return either a single object with an ``encode`` method,
    or a pair of encoders. With a pair, the first encodes the anchor and the
    second encodes the inputs.

    Returns:
        A DataFrame with columns ``inputs`` and ``score``. ``score`` holds the
        similarity from ``cos_sim`` as a Python float rounded to 3 decimals.
    """
    print(model_name)
    model = load_model(model_name, model_dict)

    # Creating embeddings
    if hasattr(model, 'encode'):
        anchor_emb = model.encode(anchor)[None, :]  # add a leading batch axis
        inputs_emb = model.encode(inputs)
    else:
        assert len(model) == 2
        anchor_emb = model[0].encode(anchor)[None, :]
        inputs_emb = model[1].encode(inputs)

    # Obtaining similarity. Flatten instead of squeezing: squeeze turns the
    # result for a single input into a 0-d array, which cannot be iterated.
    similarity = np.asarray(cos_sim(anchor_emb, inputs_emb)).reshape(-1)

    # Returning a Pandas' dataframe. float() stores plain Python numbers
    # rather than array scalars.
    d = {'inputs': inputs,
         'score': [round(float(s), 3) for s in similarity]}
    df = pd.DataFrame(d, columns=['inputs', 'score'])
    return df
# Search
def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
    """Find the ``n_answers`` corpus posts closest to the query ``anchor``.

    The query is embedded with the ``"mpnet_qa"`` model. It is then matched
    against the embeddings returned by ``load_embeddings()`` using dot-product
    scoring. Each hit's ``corpus_id`` is used as an index into the posts
    returned by ``filter_questions("python")``.

    Returns:
        A tuple of three parallel lists: post titles, scores formatted with
        three decimals, and StackOverflow question URLs.
    """
    print(model_name)
    # NOTE(review): presumably the precomputed corpus embeddings were built
    # with this model, which is why no other model is accepted — confirm.
    assert model_name == "mpnet_qa"
    encoder = load_model(model_name, model_dict)

    # Encode the query as a tensor and add a leading batch axis.
    query = encoder.encode(anchor, convert_to_tensor=True)[None, :]
    print("loading embeddings")
    corpus = load_embeddings()

    # semantic_search returns one result list per query. We sent one query.
    top_hits = util.semantic_search(query, corpus, score_function=util.dot_score, top_k=n_answers)[0]

    posts = filter_questions("python")
    print(f"{len(posts)} posts found with tag: python")

    def _row(hit):
        # Resolve a hit into (title, formatted score, url).
        post = posts[hit['corpus_id']]
        return (
            post['title'],
            format(hit['score'], '.3f'),
            f"https://stackoverflow.com/q/{post['id']}",
        )

    rows = [_row(hit) for hit in top_hits]
    titles = [row[0] for row in rows]
    scores = [row[1] for row in rows]
    links = [row[2] for row in rows]
    return titles, scores, links