Spaces:
Running
Running
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import folium | |
| import sys | |
| import os | |
| # Add utils to path | |
| sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'utils')) | |
| from clean_text import clean_text | |
| from semantic_similarity import Encoder | |
| from ranker import compute_bayesian_popularity_score | |
| from main import get_recommendations | |
# ---------------------------------------------------------------------------
# Module-level start-up: load the restaurant table, score it, stack the
# pre-computed embeddings and try to bring up the semantic encoder.
# ---------------------------------------------------------------------------
print("Loading restaurant data...")
# NOTE(review): this path is resolved against the current working directory,
# not against this file's location -- confirm the expected launch directory.
data = pd.read_csv("../data/toy_data_aggregated_embeddings.csv")
print(f"Loaded {len(data)} restaurants")

# Compute Bayesian popularity scores
print("Computing popularity scores...")
data = compute_bayesian_popularity_score(data)
print("Popularity scores computed")

print("Loading pre-computed embeddings...")
# NOTE(review): the "embedding" column comes straight from read_csv with no
# converter; presumably the cells need parsing into numeric vectors before
# stacking -- verify the resulting shape/dtype printed below.
all_desc_embeddings = np.vstack(data["embedding"].values)
print(f"Loaded embeddings with shape {all_desc_embeddings.shape}")

# Initialize semantic encoder
print("Loading semantic encoder model...")
try:
    encoder = Encoder()
    print("Semantic encoder loaded")
except Exception as e:
    # Keep the name bound so later code can test `encoder is None` instead of
    # hitting a NameError when the model could not be loaded.
    encoder = None
    print(f"Warning: Could not load semantic encoder: {e}")
    print("Falling back to keyword-only search")
def create_paris_map(results_df):
    """Create an interactive folium map of Paris with one marker per restaurant.

    Each row of ``results_df`` becomes a marker whose colour reflects
    ``overall_rating`` and whose popup lists the name, rating, review count
    and popularity score.

    NOTE: marker positions are NOT real locations. Every marker is placed at
    a fresh random offset around the city centre, so positions change on
    every call.

    Args:
        results_df: DataFrame of restaurants. ``name`` is required;
            ``overall_rating``, ``review_count`` and ``pop_score`` are read
            with fallbacks when absent.

    Returns:
        The map rendered as an HTML string (``folium.Map._repr_html_()``).
    """
    from html import escape

    paris_center = [48.8566, 2.3522]
    m = folium.Map(location=paris_center, zoom_start=12, tiles='OpenStreetMap')

    for _, row in results_df.iterrows():
        # Synthetic jitter around the centre; see the docstring note.
        lat_offset = np.random.uniform(-0.05, 0.05)
        lng_offset = np.random.uniform(-0.07, 0.07)
        coords = [paris_center[0] + lat_offset, paris_center[1] + lng_offset]

        rating = float(row.get('overall_rating', 0))
        if rating >= 4.5:
            color = 'green'
        elif rating >= 4.0:
            color = 'blue'
        elif rating >= 3.5:
            color = 'orange'
        else:
            color = 'red'

        # A missing or non-numeric popularity score must render as "N/A"
        # rather than crash on the ``.2f`` format spec.
        try:
            pop_text = f"{float(row.get('pop_score')):.2f}"
        except (TypeError, ValueError):
            pop_text = 'N/A'

        # The name is free text from the dataset and the popup is rendered as
        # raw HTML, so escape it before interpolation.
        safe_name = escape(str(row['name']))

        popup_html = f"""
        <div style="width:250px">
        <h4><b>{safe_name}</b></h4>
        <p>Rating: {row.get('overall_rating', 'N/A')}</p>
        <p>Reviews: {row.get('review_count', 'N/A')}</p>
        <p>Popularity Score: {pop_text}</p>
        </div>
        """
        folium.Marker(
            location=coords,
            popup=folium.Popup(popup_html, max_width=300),
            icon=folium.Icon(color=color, icon='cutlery', prefix='fa')
        ).add_to(m)

    return m._repr_html_()
| # def semantic_search(query, data_source, num_results, use_popularity): | |
| # """Semantic search using embeddings""" | |
| # if not query.strip(): | |
| # return "Please enter a search query", None | |
| # try: | |
| # query_clean = clean_text(query) | |
| # # Generate query embedding | |
| # print(f"Encoding query: {query_clean}") | |
| # query_embedding = encoder.encode([query_clean], show_progress_bar=False) | |
| # query_embedding = query_embedding.cpu().numpy() | |
| # # Compute semantic similarity | |
| # similarities = cosine_similarity(query_embedding, all_desc_embeddings)[0] | |
| # # Combine with popularity if requested | |
| # if use_popularity: | |
| # sim_normalized = (similarities - similarities.min()) / (similarities.max() - similarities.min() + 1e-10) | |
| # pop_normalized = (data["pop_score"] - data["pop_score"].min()) / (data["pop_score"].max() - data["pop_score"].min() + 1e-10) | |
| # # Combined score: 70% semantic, 30% popularity | |
| # scores = 0.7 * sim_normalized + 0.3 * pop_normalized | |
| # else: | |
| # scores = similarities | |
| # top_indices = np.argsort(scores)[-int(num_results):][::-1] | |
| # results = data.iloc[top_indices].copy() | |
| # results['similarity_score'] = scores[top_indices] | |
| # map_html = create_paris_map(results) | |
| # output = f"Found {len(results)} restaurants for '{query}'\n" | |
| # output += f"Data Source: {data_source}\n" | |
| # output += f"Search Method: Semantic Search {'+ Popularity' if use_popularity else ''}\n\n" | |
| # for idx, (_, row) in enumerate(results.iterrows(), 1): | |
| # name = row.get('name', 'Unknown') | |
| # rating = row.get('overall_rating', 'N/A') | |
| # reviews = row.get('review_count', 'N/A') | |
| # similarity = row.get('similarity_score', 0) | |
| # pop_score = row.get('pop_score', 0) | |
| # output += f"{idx}. **{name}**\n" | |
| # output += f" Rating: {rating} | Reviews: {reviews}\n" | |
| # output += f" Match: {similarity:.3f}" | |
| # if use_popularity: | |
| # output += f" | Popularity: {pop_score:.2f}" | |
| # output += "\n" | |
| # if 'address' in row and pd.notna(row['address']): | |
| # addr = str(row['address'])[:100] | |
| # output += f" Address: {addr}\n" | |
| # output += "\n" | |
| # return output, map_html | |
| # except Exception as e: | |
| # import traceback | |
| # return f"Error: {str(e)}\n\n{traceback.format_exc()}", None | |
| # def keyword_search(query, data_source, num_results, use_popularity): | |
| # """Keyword-based search with optional popularity ranking""" | |
| # if not query.strip(): | |
| # return "Please enter a search query", None | |
| # try: | |
| # query_clean = clean_text(query).lower() | |
| # query_words = set(query_clean.split()) | |
| # scores = [] | |
| # for idx, row in data.iterrows(): | |
| # score = 0 | |
| # name = str(row.get('name', '')).lower() | |
| # # Check name matches | |
| # for word in query_words: | |
| # if word in name: | |
| # score += 2 | |
| # rating = float(row.get('overall_rating', 0)) | |
| # score += rating * 0.5 | |
| # # Add popularity if requested | |
| # if use_popularity: | |
| # pop_score = float(row.get('pop_score', 0)) | |
| # score += pop_score * 0.3 | |
| # scores.append(score) | |
| # top_indices = np.argsort(scores)[-int(num_results):][::-1] | |
| # results = data.iloc[top_indices].copy() | |
| # results['match_score'] = [scores[i] for i in top_indices] | |
| # map_html = create_paris_map(results) | |
| # output = f"Found {len(results)} restaurants for '{query}'\n" | |
| # output += f"Data Source: {data_source}\n" | |
| # output += f"Search Method: Keyword Search {'+ Popularity' if use_popularity else ''}\n\n" | |
| # for idx, (_, row) in enumerate(results.iterrows(), 1): | |
| # name = row.get('name', 'Unknown') | |
| # rating = row.get('overall_rating', 'N/A') | |
| # reviews = row.get('review_count', 'N/A') | |
| # match = row.get('match_score', 0) | |
| # pop_score = row.get('pop_score', 0) | |
| # output += f"{idx}. **{name}**\n" | |
| # output += f" Rating: {rating} | Reviews: {reviews}\n" | |
| # output += f" Match Score: {match:.2f}" | |
| # if use_popularity: | |
| # output += f" | Popularity: {pop_score:.2f}" | |
| # output += "\n" | |
| # if 'address' in row and pd.notna(row['address']): | |
| # addr = str(row['address'])[:100] | |
| # output += f" Address: {addr}\n" | |
| # output += "\n" | |
| # return output, map_html | |
| # except Exception as e: | |
| # import traceback | |
| # return f"Error: {str(e)}\n\n{traceback.format_exc()}", None | |
| # def search_restaurants(query, data_source, search_method, num_results, use_popularity): | |
| # """Main search function that routes to appropriate search method""" | |
| # if search_method == "Semantic Search" and use_semantic: | |
| # return semantic_search(query, data_source, num_results, use_popularity) | |
| # else: | |
| # return keyword_search(query, data_source, num_results, use_popularity) | |
def search_restaurants(query_input, data_source, num_results):
    """Clean the user's query and delegate to ``get_recommendations``.

    Args:
        query_input: Raw text typed into the search box.
        data_source: Selected data source. Accepted because the UI passes it,
            but currently not used by the search itself.
        num_results: Number of results requested (slider value); coerced to
            ``int`` before use.

    Returns:
        Whatever ``get_recommendations`` returns -- presumably a
        ``(results text, map HTML)`` pair, since the UI binds this function
        to two output components (verify against ``main.get_recommendations``).
        For an empty or whitespace-only query, returns
        ``("Please enter a search query", None)`` without searching.
    """
    # Guard clause: do not run the recommender on an empty query.
    if not query_input or not query_input.strip():
        return "Please enter a search query", None

    # Sliders may deliver floats; downstream code needs an integer count.
    num_results = int(num_results)
    # The candidate pool must never be smaller than the number of results.
    n_candidates = max(100, num_results)

    query_clean = clean_text(query_input)
    return get_recommendations(query_clean, n_candidates, num_results)
# Create Gradio interface
# Layout, top to bottom: header, query + data-source row, result-count row,
# search button, results/map row, clickable examples, then event wiring.
with gr.Blocks(title="Restaurant Finder", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # Advanced Restaurant Recommendation System
    ### Search Through 5,000+ Restaurants with AI-Powered Semantic Search
    Find restaurants using semantic understanding and popularity ranking!
    """)

    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Search Query",
                placeholder="e.g., Italian pasta, best sushi, romantic dinner, family-friendly pizza",
                lines=2
            )
        with gr.Column(scale=2):
            data_source = gr.Dropdown(
                choices=["Michelin", "Google", "Yelp"],
                value="Yelp",
                label="Data Source",
                info="Select restaurant data source"
            )

    with gr.Row():
        # Disabled control: search-method selector (depends on `use_semantic`).
        # with gr.Column(scale=2):
        #     search_method = gr.Radio(
        #         choices=["Keyword Search", "Semantic Search"],
        #         value="Semantic Search" if use_semantic else "Keyword Search",
        #         label="Search Method",
        #         info="Semantic uses AI embeddings, Keyword uses exact matches"
        #     )
        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=5,
                maximum=30,
                value=10,
                step=5,
                label="Results"
            )
        # Disabled control: popularity-ranking toggle.
        # with gr.Column(scale=1):
        #     use_popularity = gr.Checkbox(
        #         label="Use Popularity Ranking",
        #         value=True,
        #         info="Boost popular restaurants"
        #     )

    search_btn = gr.Button("Search Restaurants", variant="primary", size="lg")

    with gr.Row():
        with gr.Column(scale=1):
            results_output = gr.Textbox(
                label="Restaurant Results",
                lines=20,
                max_lines=30
            )
        with gr.Column(scale=1):
            map_output = gr.HTML(
                label="Paris Map"
            )

    gr.Markdown("### Try These Examples:")
    # Every result count below must sit on the slider grid (5..30, step 5).
    examples = [
        ["Italian pasta", "Yelp", 10],
        ["sushi", "Michelin", 10],
        ["romantic dinner", "Google", 10],
        ["family-friendly pizza", "Yelp", 10],
        ["best seafood", "Michelin", 10],
        ["cheap burger", "Google", 10],
    ]
    gr.Examples(
        examples=examples,
        inputs=[query_input, data_source, num_results]
    )

    # Clicking the button and pressing Enter in the query box run the same
    # search with the same inputs and outputs.
    for register in (search_btn.click, query_input.submit):
        register(
            fn=search_restaurants,
            inputs=[query_input, data_source, num_results],
            outputs=[results_output, map_output]
        )
if __name__ == "__main__":
    # One host/port pair feeds both the banner and launch(), so the printed
    # URL cannot drift from the address the server actually binds to.
    host, port = "127.0.0.1", 7860

    print("\nStarting Advanced Restaurant Finder...")
    print(f"{len(data)} restaurants ready to search")
    print("Popularity Ranking: Enabled")
    print(f"Opening at http://{host}:{port}\n")
    app.launch(share=False, server_name=host, server_port=port, inbrowser=True)