knguyen471 commited on
Commit
812c65f
·
verified ·
1 Parent(s): 65c08a9

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +8 -7
main.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import nltk
3
  import benepar
@@ -24,9 +25,9 @@ benepar.download('benepar_en3_large')
24
  # Load dataset
25
  data = pd.read_csv("data/toy_data_aggregated_embeddings.csv")
26
 
27
- # Group by source
28
- restaurant_by_source = {k: v["id"].tolist() for k, v in data.groupby("source")}
29
- restaurant_by_source
30
 
31
  # Load precomputed TF-IDF features
32
  restaurant_tfidf_features = np.load("data/toy_data_tfidf_features.npz")
@@ -77,7 +78,7 @@ def retrieve_candidates(query: str, n_candidates: int):
77
  return candidates_idx
78
 
79
 
80
- def rerank(candidates_idx: np.ndarray, n_rec: int = 10, ) -> list:
81
 
82
  # Get popularity scores for stage 1 candidates
83
  rerank_scores = data.loc[candidates_idx, "pop_score"].values
@@ -91,8 +92,8 @@ def rerank(candidates_idx: np.ndarray, n_rec: int = 10, ) -> list:
91
 
92
  return restaurant_ids
93
 
94
- def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30):
95
  query_clean = clean_text(query)
96
- candidates_idx = retrieve_candidates(query_clean, n_candidates)
97
- restaurant_ids = rerank(candidates_idx, n_rec)
98
  return restaurant_ids
 
1
+ import json
2
  import torch
3
  import nltk
4
  import benepar
 
25
  # Load dataset
26
  data = pd.read_csv("data/toy_data_aggregated_embeddings.csv")
27
 
28
+ # Load restaurant_by_source
29
+ with open("data/restaurant_by_source.json", "r") as f:
30
+ restaurant_by_source = json.load(f)
31
 
32
  # Load precomputed TF-IDF features
33
  restaurant_tfidf_features = np.load("data/toy_data_tfidf_features.npz")
 
78
  return candidates_idx
79
 
80
 
81
+ def rerank(candidates_idx: np.ndarray, n_rec: int = 10, data_source: str) -> list:
82
 
83
  # Get popularity scores for stage 1 candidates
84
  rerank_scores = data.loc[candidates_idx, "pop_score"].values
 
92
 
93
  return restaurant_ids
94
 
95
+ def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30, data_source: str = None):
96
  query_clean = clean_text(query)
97
+ candidates_idx = retrieve_candidates(query_clean, n_candidates)
98
+ restaurant_ids = rerank(candidates_idx, n_rec, data_source)
99
  return restaurant_ids