| import os |
|
|
| import ast |
| import numpy as np |
| import pandas as pd |
| from sklearn.metrics.pairwise import cosine_similarity |
|
|
def check_embeddings(filename, model, column):
    """Return embeddings for ``column``, using ``filename`` as an on-disk cache.

    If a cached array is found it is loaded with ``np.load``. Otherwise the
    embeddings are computed with ``model.encode(column.values)``, written with
    ``np.save`` and returned.

    ``np.save`` appends ``.npy`` to a path that does not already end with it,
    so the file it writes is not necessarily ``filename`` itself. Both the
    literal path and the suffixed path are probed; checking only the literal
    path would miss the cache on every call and re-encode each time.

    Args:
        filename: Path (``str`` or ``os.PathLike``) of the cache file.
        model: Object exposing ``encode(values)``.
        column: Object exposing ``.values`` (e.g. a pandas Series) holding
            the items to encode.

    Returns:
        The array loaded from the cache, or the freshly computed result of
        ``model.encode``.
    """
    path = os.fspath(filename)

    # The literal path is tried first so an existing suffix-less file is
    # still honoured; the ".npy" variant is where np.save actually writes
    # when the suffix is missing.
    candidates = [path]
    if not path.endswith(".npy"):
        candidates.append(path + ".npy")

    for candidate in candidates:
        if os.path.isfile(candidate):
            return np.load(candidate)

    embeddings = model.encode(column.values)
    np.save(path, embeddings)
    return embeddings
|
|
def find_neighbors(user_question_embedding, embeddings, k=1):
    """Return the ``k`` stored embeddings most similar to a query embedding.

    Similarity is the cosine similarity computed by scikit-learn's
    ``cosine_similarity``; higher means more similar, and results are ordered
    from most to least similar.

    Args:
        user_question_embedding: Query embedding. If it has several rows,
            only the first row is used.
        embeddings: Array of stored embeddings; each entry along the first
            axis is flattened to one row before comparison.
        k: Maximum number of neighbours to return (default 1).

    Returns:
        A tuple ``(neighbors, scores)`` where ``neighbors`` holds the indices
        into ``embeddings`` of the best matches and ``scores`` holds their
        cosine similarities, in the same order.
    """
    # One row per stored embedding, whatever its trailing shape.
    corpus = embeddings.reshape(len(embeddings), -1)

    # Compare against a single query row. Tiling the query to len(embeddings)
    # rows would make cosine_similarity build an n x n matrix of which only
    # one column is ever read.
    query = np.atleast_2d(user_question_embedding)[:1]

    similarities = cosine_similarity(corpus, query)[:, 0]

    # argsort is ascending; reverse it so the most similar entries come first.
    neighbors = np.argsort(similarities)[::-1][:k]
    return neighbors, similarities[neighbors]
|
|
def load_data(config, model):
    """Read the question/answer CSV files and ensure their embedding caches exist.

    Args:
        config: Mapping that provides the keys ``"OT_CSV_FILENAME"``,
            ``"CSV_FILENAME"``, ``"QUESTIONS_FILENAME"`` and
            ``"ANSWERS_FILENAME"``.
        model: Passed through to ``check_embeddings``, which calls its
            ``encode`` method when a cache file is missing.

    NOTE(review): the code visible here returns nothing and does not use
    ``off_topic`` or the arrays returned by ``check_embeddings`` after
    computing them. Confirm whether a return statement was lost.
    """
    off_topic = pd.read_csv(config["OT_CSV_FILENAME"], index_col=None)
    # .copy() gives an independent frame, so the column assignment below does
    # not write into a slice of the original (avoids SettingWithCopyWarning).
    off_topic = off_topic[off_topic.Questions == "Off topic"].copy()
    off_topic["Answers"] = off_topic.Answers.str.split(" ~ ")

    df = pd.read_csv(config["CSV_FILENAME"], index_col=None)
    # Keep one row per distinct answer so each answer is embedded only once.
    df_dedupe = df.drop_duplicates("Answers").reset_index(drop=True)

    # Called so the cache files get created if absent; the returned arrays
    # are discarded here.
    check_embeddings(config["QUESTIONS_FILENAME"], model, df["Variations_Q"])
    # The first six characters of every answer are dropped before encoding;
    # presumably a fixed prefix, verify against the CSV contents.
    check_embeddings(config["ANSWERS_FILENAME"], model, df_dedupe["Answers"].str[6:])
|
|
|
|