Spaces:
Running
Running
Upload 2 files
Browse files
app.py
CHANGED
|
@@ -108,6 +108,7 @@ with gr.Blocks(
|
|
| 108 |
data_source = gr.Dropdown(
|
| 109 |
choices=["Michelin Guide", "Google", "Yelp"],
|
| 110 |
value="Yelp",
|
|
|
|
| 111 |
label="Data Source",
|
| 112 |
info="Select restaurant data source"
|
| 113 |
)
|
|
@@ -142,10 +143,10 @@ with gr.Blocks(
|
|
| 142 |
|
| 143 |
examples = [
|
| 144 |
["Italian pasta", "Yelp", 10],
|
| 145 |
-
["sushi", "Michelin", 10],
|
| 146 |
["romantic dinner", "Google", 8],
|
| 147 |
["family-friendly pizza", "Yelp", 10],
|
| 148 |
-
["best seafood", "Michelin", 10],
|
| 149 |
["cheap burger", "Google", 10]
|
| 150 |
]
|
| 151 |
|
|
@@ -172,7 +173,7 @@ if __name__ == "__main__":
|
|
| 172 |
print("Opening at http://127.0.0.1:7860\n")
|
| 173 |
|
| 174 |
# if run locally
|
| 175 |
-
|
| 176 |
|
| 177 |
-
# if run on HF Space
|
| 178 |
-
app.launch(ssr_mode=False)
|
|
|
|
| 108 |
data_source = gr.Dropdown(
|
| 109 |
choices=["Michelin Guide", "Google", "Yelp"],
|
| 110 |
value="Yelp",
|
| 111 |
+
multiselect=True,
|
| 112 |
label="Data Source",
|
| 113 |
info="Select restaurant data source"
|
| 114 |
)
|
|
|
|
| 143 |
|
| 144 |
examples = [
|
| 145 |
["Italian pasta", "Yelp", 10],
|
| 146 |
+
["sushi", "Michelin Guide", 10],
|
| 147 |
["romantic dinner", "Google", 8],
|
| 148 |
["family-friendly pizza", "Yelp", 10],
|
| 149 |
+
["best seafood", "Michelin Guide", 10],
|
| 150 |
["cheap burger", "Google", 10]
|
| 151 |
]
|
| 152 |
|
|
|
|
| 173 |
print("Opening at http://127.0.0.1:7860\n")
|
| 174 |
|
| 175 |
# if run locally
|
| 176 |
+
app.launch(share=False, server_name="127.0.0.1", server_port=7860, inbrowser=True)
|
| 177 |
|
| 178 |
+
# # if run on HF Space
|
| 179 |
+
# app.launch(ssr_mode=False)
|
main.py
CHANGED
|
@@ -12,11 +12,7 @@ from utils.semantic_similarity import Encoder
|
|
| 12 |
from utils.syntactic_similarity import Parser
|
| 13 |
from utils.tfidf_similarity import TFIDF_Vectorizer
|
| 14 |
|
| 15 |
-
|
| 16 |
-
if torch.cuda.is_available():
|
| 17 |
-
torch.set_default_device("cuda")
|
| 18 |
-
else:
|
| 19 |
-
torch.set_default_device("cpu")
|
| 20 |
|
| 21 |
# Download models/data
|
| 22 |
nltk.download('punkt')
|
|
@@ -30,9 +26,7 @@ data = pd.read_csv("data/toy_data_aggregated_embeddings.csv")
|
|
| 30 |
with open("data/restaurant_by_source.json", "r") as f:
|
| 31 |
restaurant_by_source = json.load(f)
|
| 32 |
|
| 33 |
-
#
|
| 34 |
-
# restaurant_tfidf_features = np.load("data/toy_data_tfidf_features.npz")
|
| 35 |
-
|
| 36 |
print("Computing TFIDF")
|
| 37 |
tfidf_vectorizer = TFIDF_Vectorizer(load_vectorizer=False)
|
| 38 |
restaurant_tfidf_features = tfidf_vectorizer.compute_tfidf_matrix(data["review_text_clean"])
|
|
@@ -91,7 +85,7 @@ def retrieve_candidates(query: str, n_candidates: int):
|
|
| 91 |
return candidates_idx
|
| 92 |
|
| 93 |
|
| 94 |
-
def rerank(candidates_idx: np.ndarray, n_rec: int = 10, data_source: str = None) -> list:
|
| 95 |
print("Reranking...")
|
| 96 |
|
| 97 |
# Get popularity scores for stage 1 candidates
|
|
@@ -105,15 +99,18 @@ def rerank(candidates_idx: np.ndarray, n_rec: int = 10, data_source: str = None)
|
|
| 105 |
restaurant_ids = data.loc[topN_reranked_global_idx, "id"].tolist()
|
| 106 |
|
| 107 |
# Filter to only data_source
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
print(f"[RERANK] Final recommendations: {restaurant_ids}")
|
| 113 |
return restaurant_ids
|
| 114 |
|
| 115 |
-
def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30, data_source: str = None):
|
| 116 |
query_clean = clean_text(query)
|
| 117 |
candidates_idx = retrieve_candidates(query_clean, n_candidates)
|
| 118 |
-
restaurant_ids = rerank(candidates_idx, n_rec, data_source)
|
| 119 |
return restaurant_ids
|
|
|
|
| 12 |
from utils.syntactic_similarity import Parser
|
| 13 |
from utils.tfidf_similarity import TFIDF_Vectorizer
|
| 14 |
|
| 15 |
+
torch.set_default_device("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# Download models/data
|
| 18 |
nltk.download('punkt')
|
|
|
|
| 26 |
with open("data/restaurant_by_source.json", "r") as f:
|
| 27 |
restaurant_by_source = json.load(f)
|
| 28 |
|
| 29 |
+
# Compute TFIDF features
|
|
|
|
|
|
|
| 30 |
print("Computing TFIDF")
|
| 31 |
tfidf_vectorizer = TFIDF_Vectorizer(load_vectorizer=False)
|
| 32 |
restaurant_tfidf_features = tfidf_vectorizer.compute_tfidf_matrix(data["review_text_clean"])
|
|
|
|
| 85 |
return candidates_idx
|
| 86 |
|
| 87 |
|
| 88 |
+
def rerank(candidates_idx: np.ndarray, n_rec: int, data_sources: list = None) -> list:
|
| 89 |
print("Reranking...")
|
| 90 |
|
| 91 |
# Get popularity scores for stage 1 candidates
|
|
|
|
| 99 |
restaurant_ids = data.loc[topN_reranked_global_idx, "id"].tolist()
|
| 100 |
|
| 101 |
# Filter to only data_source
|
| 102 |
+
if data_sources is not None:
|
| 103 |
+
print(f"[RERANK] Filtering to only source - {data_sources}")
|
| 104 |
+
restaurant_by_source_set = set()
|
| 105 |
+
for src in data_sources:
|
| 106 |
+
restaurant_by_source_set.update(restaurant_by_source[src])
|
| 107 |
+
restaurant_ids = [x for x in restaurant_ids if x in restaurant_by_source_set]
|
| 108 |
|
| 109 |
print(f"[RERANK] Final recommendations: {restaurant_ids}")
|
| 110 |
return restaurant_ids
|
| 111 |
|
| 112 |
+
def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30, data_sources: list = None):
|
| 113 |
query_clean = clean_text(query)
|
| 114 |
candidates_idx = retrieve_candidates(query_clean, n_candidates)
|
| 115 |
+
restaurant_ids = rerank(candidates_idx, n_rec, data_sources)
|
| 116 |
return restaurant_ids
|