# team-149-project / main.py
# Uploaded by knguyen471 ("Upload 2 files", commit ce42873, verified).
import json
import torch
import nltk
import benepar
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from utils.clean_text import clean_text
from utils.semantic_similarity import Encoder
from utils.syntactic_similarity import Parser
from utils.tfidf_similarity import TFIDF_Vectorizer
torch.set_default_device("cpu")

# Download models/data used by the NLTK tokenizer and the benepar constituency parser.
nltk.download('punkt')
nltk.download('punkt_tab')
benepar.download('benepar_en3_large')

# Load dataset. Row positions in this frame are the indices produced by
# retrieve_candidates and consumed by rerank.
data = pd.read_csv("data/toy_data_aggregated_embeddings.csv")

# Load restaurant_by_source: maps a source name to the restaurant ids it
# contains (rerank uses it to restrict recommendations to chosen sources).
# JSON is UTF-8 by specification; pass the encoding explicitly so loading
# does not depend on the platform's default text encoding.
with open("data/restaurant_by_source.json", "r", encoding="utf-8") as f:
    restaurant_by_source = json.load(f)

# Compute TFIDF features over the cleaned review text of every row.
print("Computing TFIDF")
tfidf_vectorizer = TFIDF_Vectorizer(load_vectorizer=False)
restaurant_tfidf_features = tfidf_vectorizer.compute_tfidf_matrix(data["review_text_clean"])

# Extract embeddings: the CSV stores each vector as bracketed,
# whitespace-separated text ("[0.1 0.2 ...]"), so strip the brackets and
# parse the numbers, then stack them into one (n_rows, dim) matrix.
data["embedding"] = data["embedding"].apply(
    lambda x: np.fromstring(x.strip('[]'), sep=' ')
)
all_desc_embeddings = np.vstack(data["embedding"].values)

# Initialize encoder (used to embed the query in retrieve_candidates).
encoder = Encoder()

# Initialize syntactic parser (used for query/review subtree similarity).
parser = Parser()
def retrieve_candidates(query: str, n_candidates: int):
    """Stage 1: score every row of ``data`` against ``query`` and return the best.

    The combined score is ``0.8 * semantic + 0.1 * syntactic + 0.1 * TF-IDF``:
    cosine similarity of the query embedding to each row's embedding, the mean
    syntactic-subtree similarity over each row's review trees, and the TF-IDF
    score from ``tfidf_vectorizer``.

    Args:
        query: Query text (``get_recommendations`` passes it through
            ``clean_text`` first).
        n_candidates: Number of candidates to return. Values <= 0 yield an
            empty result; values larger than ``len(data)`` return every row.

    Returns:
        np.ndarray of positional row indices into ``data``, ordered from the
        highest to the lowest combined score.
    """
    print(f"Retrieving {n_candidates} candidates...")

    # Encode query
    print("[RETRIEVAL] Encoding query")
    query_emb = encoder.encode([query]).cpu().numpy()

    # Semantic similarities: one cosine score per row embedding.
    print("[RETRIEVAL] Computing semantic similarities")
    desc_sem_sim = cosine_similarity(query_emb, all_desc_embeddings)[0]

    # TF-IDF similarities
    print("[RETRIEVAL] Computing TF-IDF")
    tfidf_sim = tfidf_vectorizer.compute_tfidf_scores(query, restaurant_tfidf_features)

    # Syntactic similarities
    print("[RETRIEVAL] Computing syntactic similarities")
    parsed_query = parser.parse_text(query)
    parsed_query = parser.subtree_set(parsed_query)

    # NOTE(review): ``data`` comes from ``pd.read_csv``, which yields strings for
    # non-numeric columns; confirm ``syntactic_tree`` is deserialized into a list
    # of subtree sets before this loop, otherwise it iterates characters.
    syn_sims = []
    for trees_list in tqdm(data["syntactic_tree"], total=len(data), desc="[RETRIEVAL] Computing syntactic similarities"):
        review_sims = [
            parser.compute_syntactic_similarity(
                parsed_query,
                set() if review_tree_subs is None else review_tree_subs,
            )
            for review_tree_subs in trees_list
        ]
        # np.mean([]) is NaN (plus a RuntimeWarning); a row with no review trees
        # gets 0 so the NaN cannot leak into the ranking below.
        syn_sims.append(np.mean(review_sims) if review_sims else 0.0)

    # Combined Stage 1 score
    syn_sims = np.asarray(syn_sims, dtype=float)
    combined_stage1_scores = 0.8 * desc_sem_sim + 0.1 * syn_sims + 0.1 * tfidf_sim
    # np.argsort places NaN after every real number, so the descending cut
    # below would rank NaN scores first; push them to the bottom instead.
    combined_stage1_scores = np.where(
        np.isnan(combined_stage1_scores), -np.inf, combined_stage1_scores
    )

    # Get top N candidates for Stage 2 reranking. Reverse first, then take a
    # prefix: slicing with [-n:] would return the whole array when n == 0.
    n_keep = max(n_candidates, 0)
    candidates_idx = np.argsort(combined_stage1_scores)[::-1][:n_keep]
    print(f"[RETRIEVAL] Results: {candidates_idx}")
    return candidates_idx
def rerank(candidates_idx: np.ndarray, n_rec: int, data_sources: list = None) -> list:
    """Stage 2: order Stage 1 candidates by ``pop_score`` and keep the top ``n_rec``.

    Args:
        candidates_idx: Positional row indices into ``data`` (as returned by
            ``retrieve_candidates``).
        n_rec: Maximum number of restaurant ids to return; values <= 0 yield [].
        data_sources: Optional keys of ``restaurant_by_source``. When given, only
            candidates whose ``id`` appears in at least one of those sources are
            kept. The filter runs *before* the ``n_rec`` cut, so the result is
            not shortened when enough matching candidates exist.

    Returns:
        List of ``id`` values, highest ``pop_score`` first.

    Raises:
        KeyError: if a name in ``data_sources`` is not a key of
            ``restaurant_by_source``.
    """
    print("Reranking...")
    candidates_idx = np.asarray(candidates_idx, dtype=int)
    # Positional views of the columns: the indices from Stage 1 are row
    # positions, not index labels.
    all_ids = data["id"].to_numpy()

    # Filter to only data_source (before truncating to n_rec)
    if data_sources is not None:
        print(f"[RERANK] Filtering to only source - {data_sources}")
        allowed_ids = set()
        for src in data_sources:
            allowed_ids.update(restaurant_by_source[src])
        # NOTE(review): ids from the CSV and from the JSON file are compared
        # as-is; confirm both sides share a type (e.g. str vs int).
        keep = np.array(
            [rid in allowed_ids for rid in all_ids[candidates_idx].tolist()],
            dtype=bool,
        )
        candidates_idx = candidates_idx[keep]

    # Popularity scores for the (filtered) Stage 1 candidates. NaN would sort
    # last in argsort and thus be picked first by the descending cut.
    rerank_scores = data["pop_score"].to_numpy(dtype=float)[candidates_idx]
    rerank_scores = np.where(np.isnan(rerank_scores), -np.inf, rerank_scores)

    # Retrieve n_rec restaurants based on pop_score. Reverse, then take a
    # prefix: slicing with [-n:] would return everything when n == 0.
    n_keep = max(n_rec, 0)
    top_local_idx = np.argsort(rerank_scores)[::-1][:n_keep]
    top_global_idx = candidates_idx[top_local_idx]

    # Get restaurant_id for final recommendations
    restaurant_ids = all_ids[top_global_idx].tolist()
    print(f"[RERANK] Final recommendations: {restaurant_ids}")
    return restaurant_ids
def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30, data_sources: list = None):
    """Run the two-stage recommendation pipeline for a raw user query.

    The query is passed through ``clean_text``. Stage 1
    (``retrieve_candidates``) then selects ``n_candidates`` rows. Stage 2
    (``rerank``) orders those rows by ``pop_score``, optionally restricts
    them to ``data_sources``, and keeps at most ``n_rec``.

    Returns:
        The list of restaurant ``id`` values produced by ``rerank``.
    """
    cleaned_query = clean_text(query)
    return rerank(
        retrieve_candidates(cleaned_query, n_candidates),
        n_rec,
        data_sources,
    )