Spaces:
Sleeping
Sleeping
Upload main.py
Browse files
main.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import nltk
|
| 3 |
import benepar
|
|
@@ -24,9 +25,9 @@ benepar.download('benepar_en3_large')
|
|
| 24 |
# Load dataset
|
| 25 |
data = pd.read_csv("data/toy_data_aggregated_embeddings.csv")
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
|
| 29 |
-
restaurant_by_source
|
| 30 |
|
| 31 |
# Load precomputed TF-IDF features
|
| 32 |
restaurant_tfidf_features = np.load("data/toy_data_tfidf_features.npz")
|
|
@@ -77,7 +78,7 @@ def retrieve_candidates(query: str, n_candidates: int):
|
|
| 77 |
return candidates_idx
|
| 78 |
|
| 79 |
|
| 80 |
-
def rerank(candidates_idx: np.ndarray, n_rec: int = 10, ) -> list:
|
| 81 |
|
| 82 |
# Get popularity scores for stage 1 candidates
|
| 83 |
rerank_scores = data.loc[candidates_idx, "pop_score"].values
|
|
@@ -91,8 +92,8 @@ def rerank(candidates_idx: np.ndarray, n_rec: int = 10, ) -> list:
|
|
| 91 |
|
| 92 |
return restaurant_ids
|
| 93 |
|
| 94 |
-
def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30):
|
| 95 |
query_clean = clean_text(query)
|
| 96 |
-
candidates_idx = retrieve_candidates(query_clean, n_candidates)
|
| 97 |
-
restaurant_ids = rerank(candidates_idx, n_rec)
|
| 98 |
return restaurant_ids
|
|
|
|
| 1 |
+
import json
|
| 2 |
import torch
|
| 3 |
import nltk
|
| 4 |
import benepar
|
|
|
|
| 25 |
# Load dataset
|
| 26 |
data = pd.read_csv("data/toy_data_aggregated_embeddings.csv")
|
| 27 |
|
| 28 |
+
# Load restaurant_by_source
|
| 29 |
+
with open("data/restaurant_by_source.json", "r") as f:
|
| 30 |
+
restaurant_by_source = json.load(f)
|
| 31 |
|
| 32 |
# Load precomputed TF-IDF features
|
| 33 |
restaurant_tfidf_features = np.load("data/toy_data_tfidf_features.npz")
|
|
|
|
| 78 |
return candidates_idx
|
| 79 |
|
| 80 |
|
| 81 |
+
def rerank(candidates_idx: np.ndarray, n_rec: int = 10, data_source: str) -> list:
|
| 82 |
|
| 83 |
# Get popularity scores for stage 1 candidates
|
| 84 |
rerank_scores = data.loc[candidates_idx, "pop_score"].values
|
|
|
|
| 92 |
|
| 93 |
return restaurant_ids
|
| 94 |
|
| 95 |
+
def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30, data_source: str = None):
|
| 96 |
query_clean = clean_text(query)
|
| 97 |
+
candidates_idx = retrieve_candidates(query_clean, n_candidates)
|
| 98 |
+
restaurant_ids = rerank(candidates_idx, n_rec, data_source)
|
| 99 |
return restaurant_ids
|