knguyen471's picture
Upload 11 files
888aba6 verified
raw
history blame
11.3 kB
import html
import os
import sys

import folium
import gradio as gr
import numpy as np
import pandas as pd

# Add utils to path so the project-local modules below can be imported
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'utils'))

from clean_text import clean_text
from semantic_similarity import Encoder
from ranker import compute_bayesian_popularity_score
from main import get_recommendations
# Resolve the data file relative to this script (the same way the utils path
# is resolved) so the app does not depend on the current working directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = os.path.join(_APP_DIR, "..", "data", "toy_data_aggregated_embeddings.csv")

print("Loading restaurant data...")
data = pd.read_csv(DATA_PATH)
print(f"Loaded {len(data)} restaurants")

# Compute Bayesian popularity scores
print("Computing popularity scores...")
data = compute_bayesian_popularity_score(data)
print("Popularity scores computed")

print("Loading pre-computed embeddings...")
# NOTE(review): pd.read_csv yields each "embedding" cell as a string unless a
# converter is supplied; in that case np.vstack produces an (n, 1) array of
# strings rather than an (n, dim) float matrix — confirm the CSV format.
all_desc_embeddings = np.vstack(data["embedding"].values)
print(f"Loaded embeddings with shape {all_desc_embeddings.shape}")

# Initialize semantic encoder; on failure keep running in keyword-only mode.
print("Loading semantic encoder model...")
try:
    encoder = Encoder()
    print("Semantic encoder loaded")
except Exception as e:
    # Keep the name bound so later references cannot raise NameError.
    encoder = None
    print(f"Warning: Could not load semantic encoder: {e}")
    print("Falling back to keyword-only search")

# Records whether semantic search is available; the (currently commented-out)
# search routing in this file checks `use_semantic`.
use_semantic = encoder is not None
def create_paris_map(results_df):
    """Create interactive map of Paris restaurants.

    NOTE: marker positions are random jitter around the Paris centre (no
    coordinate columns are read), so they are placeholders and change on
    every call.

    Args:
        results_df: DataFrame of restaurants. ``name`` is required;
            ``overall_rating``, ``review_count`` and ``pop_score`` are used
            when present.

    Returns:
        str: the Folium map rendered as an HTML snippet.
    """
    paris_center = [48.8566, 2.3522]
    m = folium.Map(location=paris_center, zoom_start=12, tiles='OpenStreetMap')

    for _, row in results_df.iterrows():
        # Placeholder location: uniform jitter of +/-0.05 deg lat, +/-0.07 deg lng.
        coords = [
            paris_center[0] + np.random.uniform(-0.05, 0.05),
            paris_center[1] + np.random.uniform(-0.07, 0.07),
        ]

        # Non-numeric ratings become NaN, which fails every threshold -> 'red'.
        try:
            rating = float(row.get('overall_rating', 0))
        except (TypeError, ValueError):
            rating = float('nan')
        color = 'green' if rating >= 4.5 else 'blue' if rating >= 4.0 else 'orange' if rating >= 3.5 else 'red'

        # Missing/non-numeric popularity must not hit the ':.2f' format spec,
        # which raises ValueError on strings such as 'N/A'.
        pop_score = row.get('pop_score')
        try:
            pop_text = 'N/A' if pd.isna(pop_score) else f"{float(pop_score):.2f}"
        except (TypeError, ValueError):
            pop_text = 'N/A'

        # Escape data-derived text: names like "Bar & Grill" or values with
        # '<' would otherwise corrupt (or inject into) the popup HTML.
        name_text = html.escape(str(row['name']))
        rating_text = html.escape(str(row.get('overall_rating', 'N/A')))
        reviews_text = html.escape(str(row.get('review_count', 'N/A')))

        popup_html = (
            '<div style="width:250px">'
            f'<h4><b>{name_text}</b></h4>'
            f'<p>Rating: {rating_text}</p>'
            f'<p>Reviews: {reviews_text}</p>'
            f'<p>Popularity Score: {pop_text}</p>'
            '</div>'
        )
        folium.Marker(
            location=coords,
            popup=folium.Popup(popup_html, max_width=300),
            icon=folium.Icon(color=color, icon='cutlery', prefix='fa'),
        ).add_to(m)

    return m._repr_html_()
# def semantic_search(query, data_source, num_results, use_popularity):
# """Semantic search using embeddings"""
# if not query.strip():
# return "Please enter a search query", None
# try:
# query_clean = clean_text(query)
# # Generate query embedding
# print(f"Encoding query: {query_clean}")
# query_embedding = encoder.encode([query_clean], show_progress_bar=False)
# query_embedding = query_embedding.cpu().numpy()
# # Compute semantic similarity
# similarities = cosine_similarity(query_embedding, all_desc_embeddings)[0]
# # Combine with popularity if requested
# if use_popularity:
# sim_normalized = (similarities - similarities.min()) / (similarities.max() - similarities.min() + 1e-10)
# pop_normalized = (data["pop_score"] - data["pop_score"].min()) / (data["pop_score"].max() - data["pop_score"].min() + 1e-10)
# # Combined score: 70% semantic, 30% popularity
# scores = 0.7 * sim_normalized + 0.3 * pop_normalized
# else:
# scores = similarities
# top_indices = np.argsort(scores)[-int(num_results):][::-1]
# results = data.iloc[top_indices].copy()
# results['similarity_score'] = scores[top_indices]
# map_html = create_paris_map(results)
# output = f"Found {len(results)} restaurants for '{query}'\n"
# output += f"Data Source: {data_source}\n"
# output += f"Search Method: Semantic Search {'+ Popularity' if use_popularity else ''}\n\n"
# for idx, (_, row) in enumerate(results.iterrows(), 1):
# name = row.get('name', 'Unknown')
# rating = row.get('overall_rating', 'N/A')
# reviews = row.get('review_count', 'N/A')
# similarity = row.get('similarity_score', 0)
# pop_score = row.get('pop_score', 0)
# output += f"{idx}. **{name}**\n"
# output += f" Rating: {rating} | Reviews: {reviews}\n"
# output += f" Match: {similarity:.3f}"
# if use_popularity:
# output += f" | Popularity: {pop_score:.2f}"
# output += "\n"
# if 'address' in row and pd.notna(row['address']):
# addr = str(row['address'])[:100]
# output += f" Address: {addr}\n"
# output += "\n"
# return output, map_html
# except Exception as e:
# import traceback
# return f"Error: {str(e)}\n\n{traceback.format_exc()}", None
# def keyword_search(query, data_source, num_results, use_popularity):
# """Keyword-based search with optional popularity ranking"""
# if not query.strip():
# return "Please enter a search query", None
# try:
# query_clean = clean_text(query).lower()
# query_words = set(query_clean.split())
# scores = []
# for idx, row in data.iterrows():
# score = 0
# name = str(row.get('name', '')).lower()
# # Check name matches
# for word in query_words:
# if word in name:
# score += 2
# rating = float(row.get('overall_rating', 0))
# score += rating * 0.5
# # Add popularity if requested
# if use_popularity:
# pop_score = float(row.get('pop_score', 0))
# score += pop_score * 0.3
# scores.append(score)
# top_indices = np.argsort(scores)[-int(num_results):][::-1]
# results = data.iloc[top_indices].copy()
# results['match_score'] = [scores[i] for i in top_indices]
# map_html = create_paris_map(results)
# output = f"Found {len(results)} restaurants for '{query}'\n"
# output += f"Data Source: {data_source}\n"
# output += f"Search Method: Keyword Search {'+ Popularity' if use_popularity else ''}\n\n"
# for idx, (_, row) in enumerate(results.iterrows(), 1):
# name = row.get('name', 'Unknown')
# rating = row.get('overall_rating', 'N/A')
# reviews = row.get('review_count', 'N/A')
# match = row.get('match_score', 0)
# pop_score = row.get('pop_score', 0)
# output += f"{idx}. **{name}**\n"
# output += f" Rating: {rating} | Reviews: {reviews}\n"
# output += f" Match Score: {match:.2f}"
# if use_popularity:
# output += f" | Popularity: {pop_score:.2f}"
# output += "\n"
# if 'address' in row and pd.notna(row['address']):
# addr = str(row['address'])[:100]
# output += f" Address: {addr}\n"
# output += "\n"
# return output, map_html
# except Exception as e:
# import traceback
# return f"Error: {str(e)}\n\n{traceback.format_exc()}", None
# def search_restaurants(query, data_source, search_method, num_results, use_popularity):
# """Main search function that routes to appropriate search method"""
# if search_method == "Semantic Search" and use_semantic:
# return semantic_search(query, data_source, num_results, use_popularity)
# else:
# return keyword_search(query, data_source, num_results, use_popularity)
def search_restaurants(query_input, data_source, num_results, n_candidates=100):
    """Run a recommendation search for the Gradio UI.

    Args:
        query_input: Free-text query typed by the user.
        data_source: Value of the "Data Source" dropdown. NOTE(review): it is
            accepted to match the UI's input wiring but is not forwarded to
            ``get_recommendations`` — confirm whether it should be.
        num_results: Number of results requested via the slider. Cast to
            ``int`` because the slider value may arrive as a float.
        n_candidates: Size of the candidate pool passed to
            ``get_recommendations`` (default 100, the previous hard-coded value).

    Returns:
        The value of ``get_recommendations`` (wired by the UI to a results
        textbox and a map HTML component). For a blank or missing query,
        returns ``("Please enter a search query", None)`` without searching.
    """
    # Guard against empty input instead of running a search on nothing.
    if not query_input or not query_input.strip():
        return "Please enter a search query", None
    query_clean = clean_text(query_input)
    return get_recommendations(query_clean, n_candidates, int(num_results))
# Create Gradio interface
with gr.Blocks(title="Restaurant Finder", theme=gr.themes.Soft()) as app:
    gr.Markdown(
        "\n"
        "# Advanced Restaurant Recommendation System\n"
        "### Search Through 5,000+ Restaurants with AI-Powered Semantic Search\n"
        "Find restaurants using semantic understanding and popularity ranking!\n"
    )

    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Search Query",
                placeholder="e.g., Italian pasta, best sushi, romantic dinner, family-friendly pizza",
                lines=2,
            )
        with gr.Column(scale=2):
            data_source = gr.Dropdown(
                choices=["Michelin", "Google", "Yelp"],
                value="Yelp",
                label="Data Source",
                info="Select restaurant data source",
            )

    with gr.Row():
        # with gr.Column(scale=2):
        #     search_method = gr.Radio(
        #         choices=["Keyword Search", "Semantic Search"],
        #         value="Semantic Search" if use_semantic else "Keyword Search",
        #         label="Search Method",
        #         info="Semantic uses AI embeddings, Keyword uses exact matches"
        #     )
        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=5,
                maximum=30,
                value=10,
                step=5,
                label="Results",
            )
        # with gr.Column(scale=1):
        #     use_popularity = gr.Checkbox(
        #         label="Use Popularity Ranking",
        #         value=True,
        #         info="Boost popular restaurants"
        #     )

    search_btn = gr.Button("Search Restaurants", variant="primary", size="lg")

    with gr.Row():
        with gr.Column(scale=1):
            results_output = gr.Textbox(
                label="Restaurant Results",
                lines=20,
                max_lines=30,
            )
        with gr.Column(scale=1):
            map_output = gr.HTML(
                label="Paris Map"
            )

    search_inputs = [query_input, data_source, num_results]
    search_outputs = [results_output, map_output]

    gr.Markdown("### Try These Examples:")
    # Each example's result count must be a value the slider can take
    # (5..30 in steps of 5); 8 was not reachable.
    examples = [
        ["Italian pasta", "Yelp", 10],
        ["sushi", "Michelin", 10],
        ["romantic dinner", "Google", 10],
        ["family-friendly pizza", "Yelp", 10],
        ["best seafood", "Michelin", 10],
        ["cheap burger", "Google", 10],
    ]
    gr.Examples(
        examples=examples,
        inputs=search_inputs,
    )

    # The button click and the textbox submit event both run the same search.
    for trigger in (search_btn.click, query_input.submit):
        trigger(
            fn=search_restaurants,
            inputs=search_inputs,
            outputs=search_outputs,
        )
if __name__ == "__main__":
    # One source of truth for the bind address so the printed URL always
    # matches what app.launch() actually uses.
    server_host, server_port = "127.0.0.1", 7860
    print("\nStarting Advanced Restaurant Finder...")
    print(f"{len(data)} restaurants ready to search")
    print("Popularity Ranking: Enabled")
    print(f"Opening at http://{server_host}:{server_port}\n")
    app.launch(share=False, server_name=server_host, server_port=server_port, inbrowser=True)